Merge pull request #19779 from Huizerd:master

PiperOrigin-RevId: 205266716
This commit is contained in:
TensorFlower Gardener 2018-07-19 11:01:42 -07:00
commit 9fa89160a4
2 changed files with 10 additions and 9 deletions

View File

@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
### Adding pruning ops to the training graph
The final step involves adding ops to the training graph that monitors the
distribution of the layer's weight magnitudes and determines the layer threshold
such masking all the weights below this threshold achieves the sparsity level
desired for the current training step. This can be achieved as follows:
The final step involves adding ops to the training graph that monitor the
distribution of the layer's weight magnitudes and determine the layer threshold,
such that masking all the weights below this threshold achieves the sparsity
level desired for the current training step. This can be achieved as follows:
```python
tf.app.flags.DEFINE_string(
@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
with tf.graph.as_default():
# Create global step variable
global_step = tf.train.get_global_step()
global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
@ -103,6 +103,7 @@ with tf.graph.as_default():
mon_sess.run(mask_update_op)
```
Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
## Example: Pruning and training deep CNNs on the cifar10 dataset

View File

@ -518,11 +518,11 @@ class Pruning(object):
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
for index, mask in enumerate(masks):
for mask, threshold in zip(masks, thresholds):
if not self._exists_in_do_not_prune_list(mask.name):
summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
summary.scalar(thresholds[index].op.name + '/threshold',
thresholds[index])
summary.scalar(mask.op.name + '/sparsity',
nn_impl.zero_fraction(mask))
summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())