Add grad_pass_through utility to custom_gradients.
grad_pass_through will let user wrap any op (including non differentiable ops such as assign operations) and produce a new op where gradients can 'pass through' to the inputs. An example use case is defining an approximate 'differentiable' moving avarage, where in the forward pass we use the current value of the moving average, and in the backward pass we only use last value fed to the op. PiperOrigin-RevId: 257770003
This commit is contained in:
parent
6bcff21ad8
commit
cbe94fd66e
@ -3920,6 +3920,7 @@ cuda_py_test(
|
||||
":framework_test_lib",
|
||||
":functional_ops",
|
||||
":gradients",
|
||||
":init_ops",
|
||||
":list_ops",
|
||||
":math_grad",
|
||||
":math_ops",
|
||||
@ -3927,6 +3928,7 @@ cuda_py_test(
|
||||
":nn_ops",
|
||||
":platform_test",
|
||||
":state_grad",
|
||||
":state_ops",
|
||||
":tensor_array_grad",
|
||||
":tensor_array_ops",
|
||||
":test_ops",
|
||||
|
@ -74,7 +74,7 @@ def copy_handle_data(source_t, target_t):
|
||||
shapes, types = zip(*[(pair.shape, pair.dtype)
|
||||
for pair in handle_data.shape_and_type])
|
||||
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
|
||||
shapes = [[d.size for d in s.dim]
|
||||
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
|
||||
if not s.unknown_rank else None for s in shapes]
|
||||
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
target_t._op._graph._c_graph, # pylint: disable=protected-access
|
||||
@ -394,3 +394,56 @@ def recompute_grad(f):
|
||||
return result, grad
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@tf_export("grad_pass_through")
|
||||
def grad_pass_through(f):
|
||||
"""Creates a grad-pass-through op with the forward behavior provided in f.
|
||||
|
||||
Use this function to wrap any op, maintaining its behavior in the forward
|
||||
pass, but replacing the original op in the backward graph with an identity.
|
||||
For example:
|
||||
|
||||
```python
|
||||
x = tf.Variable(1.0, name="x")
|
||||
z = tf.Variable(3.0, name="z")
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
# y will evaluate to 9.0
|
||||
y = tf.grad_pass_through(x.assign)(z**2)
|
||||
# grads will evaluate to 6.0
|
||||
grads = tape.gradient(y, z)
|
||||
```
|
||||
|
||||
Another example is a 'differentiable' moving average approximation, where
|
||||
gradients are allowed to flow into the last value fed to the moving average,
|
||||
but the moving average is still used for the forward pass:
|
||||
|
||||
```python
|
||||
x = ... # Some scalar value
|
||||
# A moving average object, we don't need to know how this is implemented
|
||||
moving_average = MovingAverage()
|
||||
with backprop.GradientTape() as tape:
|
||||
# mavg_x will evaluate to the current running average value
|
||||
mavg_x = tf.grad_pass_through(moving_average)(x)
|
||||
grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
|
||||
```
|
||||
|
||||
Args:
|
||||
f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
|
||||
outputs.
|
||||
|
||||
Returns:
|
||||
A function `h(x)` which returns the same values as `f(x)` and whose
|
||||
gradients are the same as those of an identity function.
|
||||
"""
|
||||
@custom_gradient
|
||||
def _grad_pass_through_op(*args, **kwargs):
|
||||
def grad(*args, **kwargs):
|
||||
variables = kwargs.get("variables")
|
||||
if variables is not None:
|
||||
# Variables involved in the wrapped op will not receive gradients.
|
||||
return args, [None] * len(variables)
|
||||
return args
|
||||
return f(*args, **kwargs), grad
|
||||
return tf_decorator.make_decorator(f, _grad_pass_through_op)
|
||||
|
@ -44,12 +44,14 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import unconnected_gradients
|
||||
@ -1389,5 +1391,64 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(g, g_re)
|
||||
|
||||
|
||||
class GradPassThroughTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def test_gradients_v1(self):
|
||||
x = variable_scope.get_variable(
|
||||
name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
|
||||
use_resource=True)
|
||||
z = variable_scope.get_variable(
|
||||
name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
|
||||
use_resource=True)
|
||||
|
||||
# Verify that assign op is not differentiable
|
||||
y = state_ops.assign(x, z**2)
|
||||
grads = gradients.gradients(y, z)
|
||||
self.assertIsNone(grads[0])
|
||||
|
||||
# Verify that when the (non differentiable) assign op is wrapped with
|
||||
# grad_pass_through, gradients are correctly forwarded to the inputs.
|
||||
# Form an input as quadratic function of variable z and check that the
|
||||
# gradient of output wrt to z is correct.
|
||||
y = custom_gradient.grad_pass_through(
|
||||
lambda v: state_ops.assign(x, v))(z**2)
|
||||
grads = gradients.gradients(y, z)
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllClose(grads[0].eval(), 6.0)
|
||||
|
||||
# Verify that variables involved in the wrapped op do not receive gradients.
|
||||
y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
|
||||
grads = gradients.gradients(y, x)
|
||||
self.assertIsNone(grads[0])
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_gradients_v2(self):
|
||||
x = variables.Variable(1.0, name="x")
|
||||
z = variables.Variable(3.0, name="z")
|
||||
|
||||
# Verify that assign op is not differentiable
|
||||
with backprop.GradientTape() as tape:
|
||||
y = x.assign(z**2)
|
||||
grads = tape.gradient(y, z)
|
||||
self.assertIsNone(grads)
|
||||
|
||||
# Verify that when the (non differentiable) assign op is wrapped with
|
||||
# grad_pass_through, gradients are correctly forwarded to the inputs.
|
||||
# Form an input as quadratic function of variable z and check that the
|
||||
# gradient of output wrt to z is correct.
|
||||
with backprop.GradientTape() as tape:
|
||||
y = custom_gradient.grad_pass_through(x.assign)(z**2)
|
||||
grads = tape.gradient(y, z)
|
||||
self.assertAllClose(grads, 6.0)
|
||||
|
||||
# Verify that variables involved in the wrapped op do not receive gradients.
|
||||
with backprop.GradientTape() as tape:
|
||||
y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
|
||||
grads = tape.gradient(y, x)
|
||||
self.assertIsNone(grads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -1388,6 +1388,10 @@ tf_module {
|
||||
name: "global_variables_initializer"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "grad_pass_through"
|
||||
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "gradients"
|
||||
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
|
||||
|
@ -664,6 +664,10 @@ tf_module {
|
||||
name: "get_static_value"
|
||||
argspec: "args=[\'tensor\', \'partial\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "grad_pass_through"
|
||||
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "gradients"
|
||||
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
|
||||
|
Loading…
x
Reference in New Issue
Block a user