Merge pull request #31452 from HarikrishnanBalagopal:patch-1
PiperOrigin-RevId: 262425830
This commit is contained in:
commit
7027d3bd95
@ -138,15 +138,13 @@ class OptimizerV2(trackable.Trackable):
|
|||||||
loss = <call_loss_function>
|
loss = <call_loss_function>
|
||||||
vars = <list_of_variables>
|
vars = <list_of_variables>
|
||||||
grads = tape.gradient(loss, vars)
|
grads = tape.gradient(loss, vars)
|
||||||
|
|
||||||
|
# Process the gradients, for example cap them, etc.
|
||||||
|
# capped_grads = [MyCapper(g) for g in grads]
|
||||||
processed_grads = [process_gradient(g) for g in grads]
|
processed_grads = [process_gradient(g) for g in grads]
|
||||||
grads_and_vars = zip(processed_grads, var_list)
|
|
||||||
|
|
||||||
# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
|
# Ask the optimizer to apply the processed gradients.
|
||||||
# need to the 'gradient' part, for example cap them, etc.
|
opt.apply_gradients(zip(processed_grads, var_list))
|
||||||
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
|
|
||||||
|
|
||||||
# Ask the optimizer to apply the capped gradients.
|
|
||||||
opt.apply_gradients(capped_grads_and_vars)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use with `tf.distribute.Strategy`.
|
### Use with `tf.distribute.Strategy`.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user