Merge pull request #31452 from HarikrishnanBalagopal:patch-1

PiperOrigin-RevId: 262425830
This commit is contained in:
TensorFlower Gardener 2019-08-08 14:56:37 -07:00
commit 7027d3bd95

View File

@ -138,15 +138,13 @@ class OptimizerV2(trackable.Trackable):
loss = <call_loss_function>
vars = <list_of_variables>
grads = tape.gradient(loss, vars)
# Process the gradients, for example cap them, etc.
# capped_grads = [MyCapper(g) for g in grads]
processed_grads = [process_gradient(g) for g in grads]
grads_and_vars = zip(processed_grads, var_list)
# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)
# Ask the optimizer to apply the processed gradients.
opt.apply_gradients(zip(processed_grads, var_list))
```
### Use with `tf.distribute.Strategy`.