Changed example showing gradient processing

The original example processed the gradients twice:
1st: grads_and_vars = zip(processed_grads, var_list)
2nd: capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

The 2nd line is especially odd: the 1st line unnecessarily zips the gradients with `var_list`, only for the 2nd line to unpack the tuples again even though only the gradient part is being processed.

Refactored the example to be clearer: now a single line processes the gradients.
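
For reference, below is a minimal runnable sketch of the refactored pattern. It is not part of the commit; `clip_gradient` is a hypothetical stand-in for the docstring's `process_gradient` / `MyCapper` placeholders, and the loss and variables are toy values chosen only to make the snippet self-contained.

```python
import tensorflow as tf

# Hypothetical gradient processor standing in for `process_gradient` / `MyCapper`.
def clip_gradient(g):
  return tf.clip_by_norm(g, 1.0)

var_list = [tf.Variable([1.0, 2.0]), tf.Variable([3.0])]
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
  # Toy loss over the variables; real code would call its own loss function here.
  loss = tf.add_n([tf.reduce_sum(tf.square(v)) for v in var_list])
grads = tape.gradient(loss, var_list)

# Single processing step, then apply: no intermediate zip of unprocessed
# gradients with the variables.
processed_grads = [clip_gradient(g) for g in grads]
opt.apply_gradients(zip(processed_grads, var_list))
```

`apply_gradients` accepts any iterable of `(gradient, variable)` pairs, so zipping the already-processed gradients with the variables once, at the call site, is all that is needed.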
HarikrishnanBalagopal 2019-08-08 23:10:40 +05:30 committed by GitHub
parent 84dafbc8ec
commit 31eb0e012e


@@ -138,15 +138,13 @@ class OptimizerV2(trackable.Trackable):
 loss = <call_loss_function>
 vars = <list_of_variables>
 grads = tape.gradient(loss, vars)
+# Do whatever you need to the gradients, for example cap them, etc.
+# capped_grads = [MyCapper(g) for g in grads]
 processed_grads = [process_gradient(g) for g in grads]
-grads_and_vars = zip(processed_grads, var_list)
-# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
-# need to the 'gradient' part, for example cap them, etc.
-capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
-# Ask the optimizer to apply the capped gradients.
-opt.apply_gradients(capped_grads_and_vars)
+# Ask the optimizer to apply the processed gradients.
+opt.apply_gradients(zip(processed_grads, var_list))
 ```
 ### Use with `tf.distribute.Strategy`.