Merge pull request #19779 from Huizerd:master
PiperOrigin-RevId: 205266716
commit 9fa89160a4
@@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
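For reference, the gradual sparsity schedule this part of the README describes can be sketched in plain Python as below. The hyperparameter names follow the pruning hparams mentioned in the text; the default values and the clamping outside the begin/end steps are illustrative assumptions of the sketch, not the library's code.

```python
def gradual_sparsity(step,
                     initial_sparsity=0.0,
                     target_sparsity=0.90,
                     sparsity_function_begin_step=0,
                     sparsity_function_end_step=100000,
                     sparsity_function_exponent=3):
  """Sparsity level to enforce at a given global step (illustrative sketch)."""
  if step <= sparsity_function_begin_step:
    return initial_sparsity
  if step >= sparsity_function_end_step:
    return target_sparsity
  span = float(sparsity_function_end_step - sparsity_function_begin_step)
  progress = (step - sparsity_function_begin_step) / span
  # Polynomial ramp from initial_sparsity to target_sparsity; with the
  # exponent set to 3 the sparsity grows quickly early on and then levels off.
  return (target_sparsity +
          (initial_sparsity - target_sparsity) *
          (1.0 - progress) ** sparsity_function_exponent)
```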
### Adding pruning ops to the training graph

-The final step involves adding ops to the training graph that monitors the
-distribution of the layer's weight magnitudes and determines the layer threshold
-such masking all the weights below this threshold achieves the sparsity level
-desired for the current training step. This can be achieved as follows:
+The final step involves adding ops to the training graph that monitor the
+distribution of the layer's weight magnitudes and determine the layer threshold,
+such that masking all the weights below this threshold achieves the sparsity
+level desired for the current training step. This can be achieved as follows:

```python
tf.app.flags.DEFINE_string(
@@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
with tf.graph.as_default():

  # Create global step variable
-  global_step = tf.train.get_global_step()
+  global_step = tf.train.get_or_create_global_step()

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
@@ -103,6 +103,7 @@ with tf.graph.as_default():
    mon_sess.run(mask_update_op)
```
+Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
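For context, here is a hedged, self-contained sketch of how the snippet above is typically wired into a training loop, including how `global_step` gets incremented by passing it to `Optimizer.minimize`. The toy layer, loss, optimizer, and `MonitoredTrainingSession` setup are illustrative assumptions; the pruning calls follow the API shown in this README.

```python
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning

tf.app.flags.DEFINE_string(
    'pruning_hparams', '',
    """Comma separated list of pruning-related hyperparameters""")
FLAGS = tf.app.flags.FLAGS

with tf.Graph().as_default():
  # Create the global step variable; the pruning schedule reads it.
  global_step = tf.train.get_or_create_global_step()

  # Toy layer and loss (illustrative): wrap the weights with a pruning mask.
  inputs = tf.random_normal([8, 16])
  weights = tf.get_variable('weights', shape=[16, 4])
  logits = tf.matmul(inputs, pruning.apply_mask(weights))
  loss = tf.reduce_mean(tf.square(logits))

  # Passing global_step to minimize() is what increments it every step.
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  train_op = optimizer.minimize(loss, global_step=global_step)

  # Parse pruning hyperparameters and build the pruning object.
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
  pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

  # Op that updates the masks whenever the schedule calls for it.
  mask_update_op = pruning_obj.conditional_mask_update_op()
  pruning_obj.add_pruning_summaries()

  with tf.train.MonitoredTrainingSession() as mon_sess:
    for _ in range(100):  # short illustrative training loop
      mon_sess.run(train_op)
      mon_sess.run(mask_update_op)
```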
## Example: Pruning and training deep CNNs on the cifar10 dataset
@@ -518,11 +518,11 @@ class Pruning(object):
      summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
-      for index, mask in enumerate(masks):
+      for mask, threshold in zip(masks, thresholds):
        if not self._exists_in_do_not_prune_list(mask.name):
-          summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
-          summary.scalar(thresholds[index].op.name + '/threshold',
-                         thresholds[index])
+          summary.scalar(mask.op.name + '/sparsity',
+                         nn_impl.zero_fraction(mask))
+          summary.scalar(threshold.op.name + '/threshold', threshold)

  def print_hparams(self):
    logging.info(self._spec.to_json())
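A brief note on the summary-tag change in this hunk: for a TensorFlow variable, `.name` is the tensor name and carries an output-index suffix such as `:0`, while `.op.name` does not, which makes it the cleaner choice for a summary tag. A tiny illustrative snippet (the variable name is made up):

```python
import tensorflow as tf

with tf.variable_scope('conv1'):
  mask = tf.get_variable('mask', shape=[3, 3], initializer=tf.ones_initializer())

print(mask.name)     # conv1/mask:0  (tensor name, includes the output index)
print(mask.op.name)  # conv1/mask    (op name, used as the summary tag)
```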