From f2e22502fd58e8d81c9e080b9242375fbf2bc772 Mon Sep 17 00:00:00 2001
From: Jesse
Date: Tue, 5 Jun 2018 14:35:38 +0200
Subject: [PATCH 1/5] Updated line for creating global step + grammar

tf.train.get_global_step() returns None if there is no global step,
preventing the pruning from working. Therefore,
tf.train.get_or_create_global_step() is a safer option.
---
 tensorflow/contrib/model_pruning/README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 86f4fd6adf6..50e7e5d7cd9 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
 sparsity_function_exponent is set to 3.
 
 ### Adding pruning ops to the training graph
-The final step involves adding ops to the training graph that monitors the
-distribution of the layer's weight magnitudes and determines the layer threshold
-such masking all the weights below this threshold achieves the sparsity level
-desired for the current training step. This can be achieved as follows:
+The final step involves adding ops to the training graph that monitor the
+distribution of the layer's weight magnitudes and determine the layer threshold,
+such that masking all the weights below this threshold achieves the sparsity
+level desired for the current training step. This can be achieved as follows:
 
 ```python
 tf.app.flags.DEFINE_string(
@@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
 with tf.graph.as_default():
 
   # Create global step variable
-  global_step = tf.train.get_global_step()
+  global_step = tf.train.get_or_create_global_step()
 
   # Parse pruning hyperparameters
   pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

From f9c7fe82cb930ee26d281e4bf21211ed352d176e Mon Sep 17 00:00:00 2001
From: Jesse
Date: Tue, 5 Jun 2018 14:49:04 +0200
Subject: [PATCH 2/5] Put some emphasis on incrementing global step

Pruning will not work if the global step is not incremented
---
 tensorflow/contrib/model_pruning/README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 50e7e5d7cd9..9143d082bf0 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -103,6 +103,7 @@ with tf.graph.as_default():
 
       mon_sess.run(mask_update_op)
 ```
+Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
 
 ## Example: Pruning and training deep CNNs on the cifar10 dataset
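Taken together, patches 1 and 2 make one point: the pruning schedule is driven by the global step, so the counter must both exist and advance. Below is a minimal, self-contained TF 1.x sketch of that pattern; the toy `weights`/`loss` model is illustrative only, not code from these patches or from model_pruning.

```python
import tensorflow as tf

with tf.Graph().as_default():
  # Patch 1: get_or_create_global_step() never returns None, whereas
  # get_global_step() returns None if no global step exists yet.
  global_step = tf.train.get_or_create_global_step()

  # Toy stand-ins so the sketch runs on its own (not model_pruning code).
  weights = tf.get_variable('weights', shape=[10],
                            initializer=tf.random_normal_initializer())
  loss = tf.reduce_sum(tf.square(weights))

  # Patch 2: passing global_step to minimize() increments it on every
  # training step; the pruning schedule reads this counter, so without
  # the increment the masks are never updated.
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=global_step)

  with tf.train.MonitoredTrainingSession() as sess:
    for _ in range(3):
      sess.run(train_op)  # global_step advances: 1, 2, 3
```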
From e106a458dd26db58c7d5abbd4afef60f8ce33252 Mon Sep 17 00:00:00 2001
From: Jesse
Date: Tue, 5 Jun 2018 15:22:07 +0200
Subject: [PATCH 3/5] Prevent redundant ":0" in summary names

Take the same approach as is already used for the thresholds: use
tf.Variable.op.name instead of tf.Variable.name, to prevent TensorFlow
from complaining that summary names are illegal (due to ":").
---
 tensorflow/contrib/model_pruning/python/pruning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 4b7af18b331..e6f9acc1399 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -520,7 +520,7 @@ class Pruning(object):
     thresholds = get_thresholds()
     for index, mask in enumerate(masks):
       if not self._exists_in_do_not_prune_list(mask.name):
-        summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
+        summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
         summary.scalar(thresholds[index].op.name + '/threshold',
                        thresholds[index])
 

From 90b28b7316edb644b71b01edaaa8553d5913fc19 Mon Sep 17 00:00:00 2001
From: Jesse
Date: Wed, 6 Jun 2018 16:07:20 +0200
Subject: [PATCH 4/5] Removed redundant use of enumeration

Since every mask has an accompanying threshold, zip(masks, thresholds)
can be used instead of enumerate(masks) and indexing into thresholds.
---
 tensorflow/contrib/model_pruning/python/pruning.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index e6f9acc1399..d843fa26d57 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -518,11 +518,10 @@ class Pruning(object):
     summary.scalar('last_mask_update_step', self._last_update_step)
     masks = get_masks()
     thresholds = get_thresholds()
-    for index, mask in enumerate(masks):
+    for mask, threshold in zip(masks, thresholds):
       if not self._exists_in_do_not_prune_list(mask.name):
         summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
-        summary.scalar(thresholds[index].op.name + '/threshold',
-                       thresholds[index])
+        summary.scalar(threshold.op.name + '/threshold', threshold)
 
   def print_hparams(self):
     logging.info(self._spec.to_json())

From 11cd70438e7d7104904bf8f3b24fcaf6fd88eab5 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Mon, 2 Jul 2018 13:37:38 -0700
Subject: [PATCH 5/5] Fix lint error.

---
 tensorflow/contrib/model_pruning/python/pruning.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index d843fa26d57..da9d398cbc0 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -520,7 +520,8 @@ class Pruning(object):
     thresholds = get_thresholds()
     for mask, threshold in zip(masks, thresholds):
       if not self._exists_in_do_not_prune_list(mask.name):
-        summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
+        summary.scalar(mask.op.name + '/sparsity',
+                       nn_impl.zero_fraction(mask))
         summary.scalar(threshold.op.name + '/threshold', threshold)
 
   def print_hparams(self):
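As background for patches 3-5: `tf.Variable.name` carries a `:0` tensor suffix, which is not legal in a summary tag, while `tf.Variable.op.name` does not. A short TF 1.x sketch of the naming difference and of the `zip()` form from patch 4; the variable names and single-element lists here are illustrative, not the actual model_pruning internals.

```python
import tensorflow as tf

with tf.Graph().as_default():
  with tf.variable_scope('conv1'):
    mask = tf.get_variable('mask', shape=[3, 3],
                           initializer=tf.ones_initializer())
    threshold = tf.get_variable('threshold', shape=[],
                                initializer=tf.zeros_initializer())

  print(mask.name)     # 'conv1/mask:0' -- the ':0' is illegal in a summary tag
  print(mask.op.name)  # 'conv1/mask'   -- the form patch 3 switches to

  # Patch 4's simplification: iterate mask/threshold pairs directly with
  # zip() instead of enumerate(masks) plus indexing into thresholds.
  for m, t in zip([mask], [threshold]):
    tf.summary.scalar(m.op.name + '/sparsity', tf.nn.zero_fraction(m))
    tf.summary.scalar(t.op.name + '/threshold', t)
```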