Merge pull request #19779 from Huizerd:master
PiperOrigin-RevId: 205266716
This commit is contained in:
commit
9fa89160a4
@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
|
|||||||
sparsity_function_exponent is set to 3.
|
sparsity_function_exponent is set to 3.
|
||||||
### Adding pruning ops to the training graph
|
### Adding pruning ops to the training graph
|
||||||
|
|
||||||
The final step involves adding ops to the training graph that monitors the
|
The final step involves adding ops to the training graph that monitor the
|
||||||
distribution of the layer's weight magnitudes and determines the layer threshold
|
distribution of the layer's weight magnitudes and determine the layer threshold,
|
||||||
such masking all the weights below this threshold achieves the sparsity level
|
such that masking all the weights below this threshold achieves the sparsity
|
||||||
desired for the current training step. This can be achieved as follows:
|
level desired for the current training step. This can be achieved as follows:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tf.app.flags.DEFINE_string(
|
tf.app.flags.DEFINE_string(
|
||||||
@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
|
|||||||
with tf.graph.as_default():
|
with tf.graph.as_default():
|
||||||
|
|
||||||
# Create global step variable
|
# Create global step variable
|
||||||
global_step = tf.train.get_global_step()
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
|
||||||
# Parse pruning hyperparameters
|
# Parse pruning hyperparameters
|
||||||
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
|
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
|
||||||
@ -103,6 +103,7 @@ with tf.graph.as_default():
|
|||||||
mon_sess.run(mask_update_op)
|
mon_sess.run(mask_update_op)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
|
||||||
|
|
||||||
## Example: Pruning and training deep CNNs on the cifar10 dataset
|
## Example: Pruning and training deep CNNs on the cifar10 dataset
|
||||||
|
|
||||||
|
@ -518,11 +518,11 @@ class Pruning(object):
|
|||||||
summary.scalar('last_mask_update_step', self._last_update_step)
|
summary.scalar('last_mask_update_step', self._last_update_step)
|
||||||
masks = get_masks()
|
masks = get_masks()
|
||||||
thresholds = get_thresholds()
|
thresholds = get_thresholds()
|
||||||
for index, mask in enumerate(masks):
|
for mask, threshold in zip(masks, thresholds):
|
||||||
if not self._exists_in_do_not_prune_list(mask.name):
|
if not self._exists_in_do_not_prune_list(mask.name):
|
||||||
summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
|
summary.scalar(mask.op.name + '/sparsity',
|
||||||
summary.scalar(thresholds[index].op.name + '/threshold',
|
nn_impl.zero_fraction(mask))
|
||||||
thresholds[index])
|
summary.scalar(threshold.op.name + '/threshold', threshold)
|
||||||
|
|
||||||
def print_hparams(self):
|
def print_hparams(self):
|
||||||
logging.info(self._spec.to_json())
|
logging.info(self._spec.to_json())
|
||||||
|
Loading…
Reference in New Issue
Block a user