Merge pull request #31452 from HarikrishnanBalagopal:patch-1
PiperOrigin-RevId: 262425830
commit 7027d3bd95
@@ -138,15 +138,13 @@ class OptimizerV2(trackable.Trackable):
 loss = <call_loss_function>
 vars = <list_of_variables>
 grads = tape.gradient(loss, vars)

+# Process the gradients, for example cap them, etc.
+# capped_grads = [MyCapper(g) for g in grads]
 processed_grads = [process_gradient(g) for g in grads]
 grads_and_vars = zip(processed_grads, var_list)

-# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
-# need to the 'gradient' part, for example cap them, etc.
-capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
-
-# Ask the optimizer to apply the capped gradients.
-opt.apply_gradients(capped_grads_and_vars)
+# Ask the optimizer to apply the processed gradients.
+opt.apply_gradients(zip(processed_grads, var_list))
 ```

 ### Use with `tf.distribute.Strategy`.
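For reference, here is a minimal runnable sketch of the corrected pattern the new docstring text describes: compute gradients with `tf.GradientTape`, process them, then apply the processed gradients with `apply_gradients`. The SGD optimizer, the two toy variables, the quadratic loss, and `tf.clip_by_norm` standing in for the docstring's placeholder `process_gradient` are illustrative assumptions, not part of the original example.

```python
import tensorflow as tf

# Create an optimizer and a couple of toy variables.
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
var_list = [tf.Variable(3.0), tf.Variable(-2.0)]

# Compute the gradients of a simple quadratic loss.
with tf.GradientTape() as tape:
  loss = sum(v ** 2 for v in var_list)
grads = tape.gradient(loss, var_list)

# Process the gradients, here by capping each one by norm
# (an assumed stand-in for the docstring's process_gradient).
processed_grads = [tf.clip_by_norm(g, 1.0) for g in grads]

# Ask the optimizer to apply the processed gradients.
opt.apply_gradients(zip(processed_grads, var_list))

print([v.numpy() for v in var_list])
```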