Moving model_pruning library to tf.contrib

PiperOrigin-RevId: 174214419
This commit is contained in:
A. Unique TensorFlower 2017-11-01 11:55:32 -07:00 committed by TensorFlower Gardener
parent 693325c832
commit 7ece1c0b8e
20 changed files with 3793 additions and 0 deletions

View File

@ -413,6 +413,7 @@ filegroup(
"//tensorflow/contrib/makefile:all_files",
"//tensorflow/contrib/meta_graph_transform:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/model_pruning:all_files",
"//tensorflow/contrib/mpi_collectives:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nearest_neighbor:all_files",

View File

@ -57,6 +57,7 @@ py_library(
"//tensorflow/contrib/memory_stats:memory_stats_py",
"//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/model_pruning",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",

View File

@ -51,6 +51,7 @@ from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import memory_stats
from tensorflow.contrib import metrics
from tensorflow.contrib import model_pruning
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
from tensorflow.contrib import opt

View File

@ -518,6 +518,11 @@ add_python_module("tensorflow/contrib/metrics/python")
add_python_module("tensorflow/contrib/metrics/python/kernel_tests")
add_python_module("tensorflow/contrib/metrics/python/metrics")
add_python_module("tensorflow/contrib/metrics/python/ops")
add_python_module("tensorflow/contrib/model_pruning")
add_python_module("tensorflow/contrib/model_pruning/examples")
add_python_module("tensorflow/contrib/model_pruning/examples/cifar10")
add_python_module("tensorflow/contrib/model_pruning/python")
add_python_module("tensorflow/contrib/model_pruning/python/layers")
add_python_module("tensorflow/contrib/ndlstm")
add_python_module("tensorflow/contrib/ndlstm/python")
add_python_module("tensorflow/contrib/nn")

View File

@ -0,0 +1,139 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "core_layers",
srcs = ["python/layers/core_layers.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:layers",
"//tensorflow/python:ops",
"//tensorflow/python:platform",
],
)
py_library(
name = "layers",
srcs = ["python/layers/layers.py"],
srcs_version = "PY2AND3",
deps = [
":core_layers",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/layers:layers_py",
"//third_party/py/numpy",
],
)
py_test(
name = "layers_test",
size = "small",
srcs = ["python/layers/layers_test.py"],
srcs_version = "PY2AND3",
deps = [
":layers",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "learning",
srcs = ["python/learning.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/slim",
],
)
py_library(
name = "rnn_cells",
srcs = ["python/layers/rnn_cells.py"],
srcs_version = "PY2AND3",
deps = [
":core_layers",
],
)
py_library(
name = "pruning",
srcs = ["python/pruning.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":core_layers",
"//tensorflow/contrib/training:training_py",
"//tensorflow/python:platform",
"//third_party/py/numpy",
],
)
py_test(
name = "pruning_test",
size = "small",
srcs = ["python/pruning_test.py"],
srcs_version = "PY2AND3",
deps = [
":pruning",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "rnn_cells_test",
size = "small",
srcs = ["python/layers/rnn_cells_test.py"],
srcs_version = "PY2AND3",
deps = [
":pruning",
":rnn_cells",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "init_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
)
# Top-level library
py_library(
name = "model_pruning",
srcs_version = "PY2AND3",
deps = [
":init_py",
":layers",
":learning",
":pruning",
":rnn_cells",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,195 @@
# Model pruning: Training tensorflow models to have masked connections
This document describes the API that facilitates magnitude-based pruning of
neural network's weight tensors. The API helps inject necessary tensorflow op
into the training graph so the model can be pruned while it is being trained.
### Model creation
The first step involves adding mask and threshold variables to the layers that
need to undergo pruning. The variable mask is the same shape as the layer's
weight tensor and determines which of the weights participate in the forward
execution of the graph. This can be achieved by wrapping the weight tensor of
the layer with the `apply_mask` function provided in
[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py).
For example:
```python
conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding)
```
This creates a convolutional layer with additional variables mask and threshold
as shown below: ![Convolutional layer with mask and
threshold](./mask.png "Convolutional layer with mask and threshold")
Alternatively, the API also provides variant of tensorflow layers with these
auxiliary variables built-in (see
[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers))
. Layers currently supported:
* [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83)
* [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241)
* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)
### Adding pruning ops to the training graph
The pruning library allows for specification of the following hyper parameters:
| Hyperparameter | Type | Default | Description |
| ---------------------------- | ------- | ------------- | -------------- |
| name | string | model_pruning | Name of the |
: : : : pruning :
: : : : specification. :
: : : : Used for :
: : : : adding :
: : : : summaries and :
: : : : ops under a :
: : : : common :
: : : : tensorflow :
: : : : name_scope :
| begin_pruning_step | integer | 0 | The global |
: : : : step at which :
: : : : to begin :
: : : : pruning :
| end_pruning_step | integer | -1 | The global |
: : : : step at which :
: : : : to terminate :
: : : : pruning. :
: : : : Defaults to -1 :
: : : : implying that :
: : : : pruning :
: : : : continues till :
: : : : the training :
: : : : stops :
| do_not_prune | list of | [""] | list of layers |
: : strings : : that are not :
: : : : pruned :
| threshold_decay | float | 0.9 | The decay |
: : : : factor to use :
: : : : for :
: : : : exponential :
: : : : decay of the :
: : : : thresholds :
| pruning_frequency | integer | 10 | How often |
: : : : should the :
: : : : masks be :
: : : : updated? (in # :
: : : : of :
: : : : global_steps). :
| nbins | integer | 255 | Number of bins |
: : : : to use for :
: : : : histogram :
: : : : computation :
| initial_sparsity | float | 0.0 | Initial |
: : : : sparsity value :
| target_sparsity | float | 0.5 | Target |
: : : : sparsity value :
| sparsity_function_begin_step | integer | 0 | The global |
: : : : step at this :
: : : : which the :
: : : : gradual :
: : : : sparsity :
: : : : function :
: : : : begins to take :
: : : : effect :
| sparsity_function_end_step | integer | 100 | The global |
: : : : step used as :
: : : : the end point :
: : : : for the :
: : : : gradual :
: : : : sparsity :
: : : : function :
| sparsity_function_exponent | float | 3.0 | exponent = 1 |
: : : : is linearly :
: : : : varying :
: : : : sparsity :
: : : : between :
: : : : initial and :
: : : : final. :
: : : : exponent > 1 :
: : : : varies more :
: : : : slowly towards :
: : : : the end than :
: : : : the beginning :
The sparsity $$s_t$$ at global step $$t$$ is given by:
$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$
The interval between sparsity_function_begin_step and sparsity_function_end_step
is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$
is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
### Adding pruning ops to the training graph
The final step involves adding ops to the training graph that monitors the
distribution of the layer's weight magnitudes and determines the layer threshold
such masking all the weights below this threshold achieves the sparsity level
desired for the current training step. This can be achieved as follows:
```python
tf.app.flags.DEFINE_string(
'pruning_hparams', '',
"""Comma separated list of pruning-related hyperparameters""")
with tf.graph.as_default():
# Create global step variable
global_step = tf.train.get_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
# Create a pruning object using the pruning specification
p = pruning.Pruning(pruning_hparams, global_step=global_step)
# Add conditional mask update op. Executing this op will update all
# the masks in the graph if the current global step is in the range
# [begin_pruning_step, end_pruning_step] as specified by the pruning spec
mask_update_op = p.conditional_mask_update_op()
# Add summaries to keep track of the sparsity in different layers during training
p.add_pruning_summaries()
with tf.train.MonitoredTrainingSession(...) as mon_sess:
# Run the usual training op in the tf session
mon_sess.run(train_op)
# Update the masks by running the mask_update_op
mon_sess.run(mask_update_op)
```
## Example: Pruning and training deep CNNs on the cifar10 dataset
Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural
network architecture, setting up inputs etc. The additional changes needed to
incorporate pruning are captured in the following:
* [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py)
creates a deep CNN with the same architecture, but adds mask and threshold
variables for each of the weight tensors in the convolutional and
locally-connected layers.
* [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py)
add pruning ops to the training graph as described above.
To train the pruned version of cifar10:
```bash
$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
```
Eval:
```shell
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```
TODO(suyoggupta): Add figures showing the sparsity function, sparsity for
different layers etc.

View File

@ -0,0 +1,46 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model pruning implementation in tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.model_pruning.python.layers.layers import masked_conv2d
from tensorflow.contrib.model_pruning.python.layers.layers import masked_convolution
from tensorflow.contrib.model_pruning.python.layers.layers import masked_fully_connected
from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell
from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedLSTMCell
from tensorflow.contrib.model_pruning.python.learning import train
from tensorflow.contrib.model_pruning.python.pruning import apply_mask
from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights
from tensorflow.contrib.model_pruning.python.pruning import get_masks
from tensorflow.contrib.model_pruning.python.pruning import get_thresholds
from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity
from tensorflow.contrib.model_pruning.python.pruning import get_weights
from tensorflow.contrib.model_pruning.python.pruning import Pruning
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'masked_convolution', 'masked_conv2d', 'masked_fully_connected',
'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask',
'get_masked_weights', 'get_masks', 'get_thresholds', 'get_weights',
'get_weight_sparsity', 'Pruning'
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,77 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Description:
# Example TensorFlow models for CIFAR-10
package(
default_visibility = [
"//tensorflow:internal",
],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "cifar10_input",
srcs = ["cifar10_input.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "cifar10_pruning",
srcs = ["cifar10_pruning.py"],
srcs_version = "PY2AND3",
deps = [
":cifar10_input",
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "cifar10_eval",
srcs = [
"cifar10_eval.py",
],
srcs_version = "PY2AND3",
deps = [
":cifar10_pruning",
],
)
py_binary(
name = "cifar10_train",
srcs = [
"cifar10_train.py",
],
srcs_version = "PY2AND3",
deps = [
":cifar10_pruning",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,178 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation for CIFAR-10.
Accuracy:
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
of data) as judged by cifar10_eval.py.
Speed:
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
accuracy after 100K steps in 8 hours of training time.
Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import datetime
import math
import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
FLAGS = None
def eval_once(saver, summary_writer, top_k_op, summary_op):
"""Run Eval once.
Args:
saver: Saver.
summary_writer: Summary writer.
top_k_op: Top K op.
summary_op: Summary op.
"""
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
# Assuming model_checkpoint_path looks something like:
# /my-favorite-path/cifar10_train/model.ckpt-0,
# extract global_step from it.
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
else:
print('No checkpoint file found')
return
# Start the queue runners.
coord = tf.train.Coordinator()
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
start=True))
num_iter = int(math.ceil(FLAGS.num_examples / 128))
true_count = 0 # Counts the number of correct predictions.
total_sample_count = num_iter * 128
step = 0
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_k_op])
true_count += np.sum(predictions)
step += 1
# Compute precision @ 1.
precision = true_count / total_sample_count
print('%s: precision @ 1 = %.3f' % (datetime.datetime.now(), precision))
summary = tf.Summary()
summary.ParseFromString(sess.run(summary_op))
summary.value.add(tag='Precision @ 1', simple_value=precision)
summary_writer.add_summary(summary, global_step)
except Exception as e: # pylint: disable=broad-except
coord.request_stop(e)
coord.request_stop()
coord.join(threads, stop_grace_period_secs=10)
def evaluate():
"""Eval CIFAR-10 for a number of steps."""
with tf.Graph().as_default() as g:
# Get images and labels for CIFAR-10.
eval_data = FLAGS.eval_data == 'test'
images, labels = cifar10.inputs(eval_data=eval_data)
# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)
# Calculate predictions.
top_k_op = tf.nn.in_top_k(logits, labels, 1)
# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(
cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# Build the summary operation based on the TF collection of Summaries.
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
while True:
eval_once(saver, summary_writer, top_k_op, summary_op)
if FLAGS.run_once:
break
time.sleep(FLAGS.eval_interval_secs)
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.DeleteRecursively(FLAGS.eval_dir)
tf.gfile.MakeDirs(FLAGS.eval_dir)
evaluate()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--eval_dir',
type=str,
default='/tmp/cifar10_eval',
help='Directory where to write event logs.')
parser.add_argument(
'--eval_data',
type=str,
default='test',
help="""Either 'test' or 'train_eval'.""")
parser.add_argument(
'--checkpoint_dir',
type=str,
default='/tmp/cifar10_train',
help="""Directory where to read model checkpoints.""")
parser.add_argument(
'--eval_interval_secs',
type=int,
default=60 * 5,
help='How often to run the eval.')
parser.add_argument(
'--num_examples',
type=int,
default=10000,
help='Number of examples to run.')
parser.add_argument(
'--run_once',
type=bool,
default=False,
help='Whether to run eval only once.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,256 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Routine for decoding the CIFAR-10 binary file format."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
# architecture will change and any model would need to be retrained.
IMAGE_SIZE = 24
# Global constants describing the CIFAR-10 data set.
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
def read_cifar10(filename_queue):
"""Reads and parses examples from CIFAR10 data files.
Recommendation: if you want N-way read parallelism, call this function
N times. This will give you N independent Readers reading different
files & positions within those files, which will give better mixing of
examples.
Args:
filename_queue: A queue of strings with the filenames to read from.
Returns:
An object representing a single example, with the following fields:
height: number of rows in the result (32)
width: number of columns in the result (32)
depth: number of color channels in the result (3)
key: a scalar string Tensor describing the filename & record number
for this example.
label: an int32 Tensor with the label in the range 0..9.
uint8image: a [height, width, depth] uint8 Tensor with the image data
"""
class CIFAR10Record(object):
pass
result = CIFAR10Record()
# Dimensions of the images in the CIFAR-10 dataset.
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
# input format.
label_bytes = 1 # 2 for CIFAR-100
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
# Every record consists of a label followed by the image, with a
# fixed number of bytes for each.
record_bytes = label_bytes + image_bytes
# Read a record, getting filenames from the filename_queue. No
# header or footer in the CIFAR-10 format, so we leave header_bytes
# and footer_bytes at their default of 0.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
# Convert from a string to a vector of uint8 that is record_bytes long.
record_bytes = tf.decode_raw(value, tf.uint8)
# The first bytes represent the label, which we convert from uint8->int32.
result.label = tf.cast(
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(
tf.strided_slice(record_bytes, [label_bytes],
[label_bytes + image_bytes]),
[result.depth, result.height, result.width])
# Convert from [depth, height, width] to [height, width, depth].
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
def _generate_image_and_label_batch(image, label, min_queue_examples,
batch_size, shuffle):
"""Construct a queued batch of images and labels.
Args:
image: 3-D Tensor of [height, width, 3] of type.float32.
label: 1-D Tensor of type.int32
min_queue_examples: int32, minimum number of samples to retain
in the queue that provides of batches of examples.
batch_size: Number of images per batch.
shuffle: boolean indicating whether to use a shuffling queue.
Returns:
images: Images. 4D tensor of [batch_size, height, width, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
# Create a queue that shuffles the examples, and then
# read 'batch_size' images + labels from the example queue.
num_preprocess_threads = 16
if shuffle:
images, label_batch = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
else:
images, label_batch = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size)
# Display the training images in the visualizer.
tf.summary.image('images', images)
return images, tf.reshape(label_batch, [batch_size])
def distorted_inputs(data_dir, batch_size):
"""Construct distorted input for CIFAR training using the Reader ops.
Args:
data_dir: Path to the CIFAR-10 data directory.
batch_size: Number of images per batch.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue.
read_input = read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE
width = IMAGE_SIZE
# Image processing for training the network. Note the many random
# distortions applied to the image.
# Randomly crop a [height, width] section of the image.
distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
# Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image)
# Because these operations are not commutative, consider randomizing
# the order their operation.
distorted_image = tf.image.random_brightness(distorted_image,
max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,
lower=0.2, upper=1.8)
# Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(distorted_image)
# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
min_fraction_of_examples_in_queue)
print ('Filling queue with %d CIFAR images before starting to train. '
'This will take a few minutes.' % min_queue_examples)
# Generate a batch of images and labels by building up a queue of examples.
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size,
shuffle=True)
def inputs(eval_data, data_dir, batch_size):
"""Construct input for CIFAR evaluation using the Reader ops.
Args:
eval_data: bool, indicating if one should use the train or eval data set.
data_dir: Path to the CIFAR-10 data directory.
batch_size: Number of images per batch.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
if not eval_data:
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
else:
filenames = [os.path.join(data_dir, 'test_batch.bin')]
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue.
read_input = read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE
width = IMAGE_SIZE
# Image processing for evaluation.
# Crop the central [height, width] of the image.
resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
width, height)
# Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(resized_image)
# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(num_examples_per_epoch *
min_fraction_of_examples_in_queue)
# Generate a batch of images and labels by building up a queue of examples.
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size,
shuffle=False)

View File

@ -0,0 +1,395 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the CIFAR-10 network with additional variables to support pruning.
Summary of available functions:
# Compute input images and labels for training. If you would like to run
# evaluations, use inputs() instead.
inputs, labels = distorted_inputs()
# Compute inference on the model inputs to make a prediction.
predictions = inference(inputs)
# Compute the total loss of the prediction with respect to the labels.
loss = loss(predictions, labels)
# Create a graph to run one step of training with respect to the loss.
train_op = train(loss, global_step)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import sys
import tarfile
from six.moves import urllib
import tensorflow as tf
from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input
from tensorflow.contrib.model_pruning.python import pruning
# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
BATCH_SIZE = 128
DATA_DIR = '/tmp/cifar10_data'
# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
def _activation_summary(x):
"""Helper to create summaries for activations.
Creates a summary that provides a histogram of activations.
Creates a summary that measures the sparsity of activations.
Args:
x: Tensor
Returns:
nothing
"""
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
tf.summary.histogram(tensor_name + '/activations', x)
tf.summary.scalar(tensor_name + '/sparsity',
tf.nn.zero_fraction(x))
def _variable_on_cpu(name, shape, initializer):
"""Helper to create a Variable stored on CPU memory.
Args:
name: name of the variable
shape: list of ints
initializer: initializer for Variable
Returns:
Variable Tensor
"""
with tf.device('/cpu:0'):
dtype = tf.float32
var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
return var
def _variable_with_weight_decay(name, shape, stddev, wd):
"""Helper to create an initialized Variable with weight decay.
Note that the Variable is initialized with a truncated normal distribution.
A weight decay is added only if one is specified.
Args:
name: name of the variable
shape: list of ints
stddev: standard deviation of a truncated Gaussian
wd: add L2Loss weight decay multiplied by this float. If None, weight
decay is not added for this Variable.
Returns:
Variable Tensor
"""
dtype = tf.float32
var = _variable_on_cpu(
name,
shape,
tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
if wd is not None:
weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)
return var
def distorted_inputs():
"""Construct distorted input for CIFAR training using the Reader ops.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
Raises:
ValueError: If no data_dir
"""
if not DATA_DIR:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
images, labels = cifar10_input.distorted_inputs(
data_dir=data_dir, batch_size=BATCH_SIZE)
return images, labels
def inputs(eval_data):
"""Construct input for CIFAR evaluation using the Reader ops.
Args:
eval_data: bool, indicating if one should use the train or eval data set.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
Raises:
ValueError: If no data_dir
"""
if not DATA_DIR:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
images, labels = cifar10_input.inputs(
eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE)
return images, labels
def inference(images):
"""Build the CIFAR-10 model.
Args:
images: Images returned from distorted_inputs() or inputs().
Returns:
Logits.
"""
# We instantiate all variables using tf.get_variable() instead of
# tf.Variable() in order to share variables across multiple GPU training runs.
# If we only ran this model on a single GPU, we could simplify this function
# by replacing all instances of tf.get_variable() with tf.Variable().
#
# While instantiating conv and local layers, we add mask and threshold
# variables to the layer by calling the pruning.apply_mask() function.
# Note that the masks are applied only to the weight tensors
# conv1
with tf.variable_scope('conv1') as scope:
kernel = _variable_with_weight_decay('weights',
shape=[5, 5, 3, 64],
stddev=5e-2,
wd=0.0)
conv = tf.nn.conv2d(
images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
pre_activation = tf.nn.bias_add(conv, biases)
conv1 = tf.nn.relu(pre_activation, name=scope.name)
_activation_summary(conv1)
# pool1
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
padding='SAME', name='pool1')
# norm1
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
name='norm1')
# conv2
with tf.variable_scope('conv2') as scope:
kernel = _variable_with_weight_decay('weights',
shape=[5, 5, 64, 64],
stddev=5e-2,
wd=0.0)
conv = tf.nn.conv2d(
norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
pre_activation = tf.nn.bias_add(conv, biases)
conv2 = tf.nn.relu(pre_activation, name=scope.name)
_activation_summary(conv2)
# norm2
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
name='norm2')
# pool2
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1], padding='SAME', name='pool2')
# local3
with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
dim = reshape.get_shape()[1].value
weights = _variable_with_weight_decay('weights', shape=[dim, 384],
stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
local3 = tf.nn.relu(
tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
name=scope.name)
_activation_summary(local3)
# local4
with tf.variable_scope('local4') as scope:
weights = _variable_with_weight_decay('weights', shape=[384, 192],
stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
local4 = tf.nn.relu(
tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
name=scope.name)
_activation_summary(local4)
# linear layer(WX + b),
# We don't apply softmax here because
# tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
# and performs the softmax internally for efficiency.
with tf.variable_scope('softmax_linear') as scope:
weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
stddev=1/192.0, wd=0.0)
biases = _variable_on_cpu('biases', [NUM_CLASSES],
tf.constant_initializer(0.0))
softmax_linear = tf.add(
tf.matmul(local4, pruning.apply_mask(weights, scope)),
biases,
name=scope.name)
_activation_summary(softmax_linear)
return softmax_linear
def loss(logits, labels):
"""Add L2Loss to all the trainable variables.
Add summary for "Loss" and "Loss/avg".
Args:
logits: Logits from inference().
labels: Labels from distorted_inputs or inputs(). 1-D tensor
of shape [batch_size]
Returns:
Loss tensor of type float.
"""
# Calculate the average cross entropy loss across the batch.
labels = tf.cast(labels, tf.int64)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits, name='cross_entropy_per_example')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
tf.add_to_collection('losses', cross_entropy_mean)
# The total loss is defined as the cross entropy loss plus all of the weight
# decay terms (L2 loss).
return tf.add_n(tf.get_collection('losses'), name='total_loss')
def _add_loss_summaries(total_loss):
"""Add summaries for losses in CIFAR-10 model.
Generates moving average for all losses and associated summaries for
visualizing the performance of the network.
Args:
total_loss: Total loss from loss().
Returns:
loss_averages_op: op for generating moving averages of losses.
"""
# Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
losses = tf.get_collection('losses')
loss_averages_op = loss_averages.apply(losses + [total_loss])
# Attach a scalar summary to all individual losses and the total loss; do the
# same for the averaged version of the losses.
for l in losses + [total_loss]:
# Name each loss as '(raw)' and name the moving average version of the loss
# as the original loss name.
tf.summary.scalar(l.op.name + ' (raw)', l)
tf.summary.scalar(l.op.name, loss_averages.average(l))
return loss_averages_op
def train(total_loss, global_step):
"""Train CIFAR-10 model.
Create an optimizer and apply to all trainable variables. Add moving
average for all trainable variables.
Args:
total_loss: Total loss from loss().
global_step: Integer Variable counting the number of training steps
processed.
Returns:
train_op: op for training.
"""
# Variables that affect learning rate.
num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
# Decay the learning rate exponentially based on the number of steps.
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
global_step,
decay_steps,
LEARNING_RATE_DECAY_FACTOR,
staircase=True)
tf.summary.scalar('learning_rate', lr)
# Generate moving averages of all losses and associated summaries.
loss_averages_op = _add_loss_summaries(total_loss)
# Compute gradients.
with tf.control_dependencies([loss_averages_op]):
opt = tf.train.GradientDescentOptimizer(lr)
grads = opt.compute_gradients(total_loss)
# Apply gradients.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
# Add histograms for trainable variables.
for var in tf.trainable_variables():
tf.summary.histogram(var.op.name, var)
# Add histograms for gradients.
for grad, var in grads:
if grad is not None:
tf.summary.histogram(var.op.name + '/gradients', grad)
# Track the moving averages of all trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(
MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train')
return train_op
def maybe_download_and_extract():
"""Download and extract the tarball from Alex's website."""
dest_directory = DATA_DIR
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)

View File

@ -0,0 +1,159 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train pruned CIFAR-10 using a single GPU.
Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py when target sparsity in
cifar10_pruning_spec.pbtxt is set to zero
Results:
Sparsity | Accuracy after 150K steps
-------- | -------------------------
0% | 86%
50% | 86%
75% | TODO(suyoggupta)
90% | TODO(suyoggupta)
95% | 77%
Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import datetime
import sys
import time
import tensorflow as tf
from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
from tensorflow.contrib.model_pruning.python import pruning
FLAGS = None
def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()
# Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs()
# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)
# Calculate loss.
loss = cifar10.loss(logits, labels)
# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step)
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
# Create a pruning object using the pruning hyperparameters
pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
# Use the pruning_obj to add ops to the training graph to update the masks
# The conditional_mask_update_op will update the masks only when the
# training step is in [begin_pruning_step, end_pruning_step] specified in
# the pruning spec proto
mask_update_op = pruning_obj.conditional_mask_update_op()
# Use the pruning_obj to add summaries to the graph to track the sparsity
# of each of the layers
pruning_obj.add_pruning_summaries()
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
self._start_time = time.time()
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
duration = time.time() - self._start_time
loss_value = run_values.results
if self._step % 10 == 0:
num_examples_per_step = 128
examples_per_sec = num_examples_per_step / duration
sec_per_batch = float(duration)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print(format_str % (datetime.datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
# Update the masks
mon_sess.run(mask_update_op)
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--train_dir',
type=str,
default='/tmp/cifar10_train',
help='Directory where to write event logs and checkpoint.')
parser.add_argument(
'--pruning_hparams',
type=str,
default='',
help="""Comma separated list of pruning-related hyperparameters""")
parser.add_argument(
'--max_steps',
type=int,
default=1000000,
help='Number of batches to run.')
parser.add_argument(
'--log_device_placement',
type=bool,
default=False,
help='Whether to log device placement.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,477 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the core layer classes for model pruning and its functional aliases.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
MASK_COLLECTION = 'masks'
THRESHOLD_COLLECTION = 'thresholds'
MASKED_WEIGHT_COLLECTION = 'masked_weights'
WEIGHT_COLLECTION = 'kernel'
# The 'weights' part of the name is needed for the quantization library
# to recognize that the kernel should be quantized.
MASKED_WEIGHT_NAME = 'weights/masked_weight'
class _MaskedConv(base.Layer):
"""Abstract nD convolution layer (private, used as implementation base).
This layer creates a convolution kernel that is convolved
(actually cross-correlated) with the layer input to produce a tensor of
outputs. The weight tensor of this layer is masked.
If `use_bias` is True (and a `bias_initializer` is provided),
a bias vector is created and added to the outputs. Finally, if
`activation` is not `None`, it is applied to the outputs as well.
Arguments:
rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
length of the convolution window.
strides: An integer or tuple/list of n integers,
specifying the stride length of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, ..., channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, ...)`.
dilation_rate: An integer or tuple/list of n integers, specifying
the dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
activation: Activation function. Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: An initializer for the convolution kernel.
bias_initializer: An initializer for the bias vector. If None, no bias will
be applied.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: A string, the name of the layer.
"""
def __init__(self,
rank,
filters,
kernel_size,
strides=1,
padding='valid',
data_format='channels_last',
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
super(_MaskedConv, self).__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
**kwargs)
self.rank = rank
self.filters = filters
self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
self.strides = utils.normalize_tuple(strides, rank, 'strides')
self.padding = utils.normalize_padding(padding)
self.data_format = utils.normalize_data_format(data_format)
self.dilation_rate = utils.normalize_tuple(dilation_rate, rank,
'dilation_rate')
self.activation = activation
self.use_bias = use_bias
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.kernel_regularizer = kernel_regularizer
self.bias_regularizer = bias_regularizer
self.input_spec = base.InputSpec(ndim=self.rank + 2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
channel_axis = 1 if self.data_format == 'channels_first' else -1
if input_shape[channel_axis].value is None:
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis].value
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.mask = self.add_variable(
name='mask',
shape=kernel_shape,
initializer=init_ops.ones_initializer(),
trainable=False,
dtype=self.dtype)
self.kernel = self.add_variable(
name='kernel',
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
trainable=True,
dtype=self.dtype)
self.threshold = self.add_variable(
name='threshold',
shape=[],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=self.dtype)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
MASKED_WEIGHT_NAME)
ops.add_to_collection(MASK_COLLECTION, self.mask)
ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)
if self.use_bias:
self.bias = self.add_variable(
name='bias',
shape=(self.filters,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
trainable=True,
dtype=self.dtype)
else:
self.bias = None
self.input_spec = base.InputSpec(
ndim=self.rank + 2, axes={channel_axis: input_dim})
self.built = True
def call(self, inputs):
outputs = nn.convolution(
input=inputs,
filter=self.masked_kernel,
dilation_rate=self.dilation_rate,
strides=self.strides,
padding=self.padding.upper(),
data_format=utils.convert_data_format(self.data_format, self.rank + 2))
if self.bias is not None:
if self.data_format == 'channels_first':
if self.rank == 1:
# nn.bias_add does not accept a 1D input tensor.
bias = array_ops.reshape(self.bias, (1, self.filters, 1))
outputs += bias
if self.rank == 2:
outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
if self.rank == 3:
# As of Mar 2017, direct addition is significantly slower than
# bias_add when computing gradients. To use bias_add, we collapse Z
# and Y into a single dimension to obtain a 4D input tensor.
outputs_shape = outputs.shape.as_list()
outputs_4d = array_ops.reshape(outputs, [
outputs_shape[0], outputs_shape[1],
outputs_shape[2] * outputs_shape[3], outputs_shape[4]
])
outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
outputs = array_ops.reshape(outputs_4d, outputs_shape)
else:
outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
if self.activation is not None:
return self.activation(outputs)
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
space = input_shape[1:-1]
new_space = []
for i in range(len(space)):
new_dim = utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
return tensor_shape.TensorShape([input_shape[0]] + new_space +
[self.filters])
else:
space = input_shape[2:]
new_space = []
for i in range(len(space)):
new_dim = utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
return tensor_shape.TensorShape([input_shape[0], self.filters] +
new_space)
class MaskedConv2D(_MaskedConv):
"""2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
(actually cross-correlated) with the layer input to produce a tensor of
outputs. If `use_bias` is True (and a `bias_initializer` is provided),
a bias vector is created and added to the outputs. Finally, if
`activation` is not `None`, it is applied to the outputs as well.
Arguments:
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
dilation_rate: An integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
activation: Activation function. Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: An initializer for the convolution kernel.
bias_initializer: An initializer for the bias vector. If None, no bias will
be applied.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: A string, the name of the layer.
"""
def __init__(self,
filters,
kernel_size,
strides=(1, 1),
padding='valid',
data_format='channels_last',
dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
super(MaskedConv2D, self).__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
trainable=trainable,
name=name,
**kwargs)
class MaskedFullyConnected(base.Layer):
"""Fully-connected layer class with masked weights.
This layer implements the operation:
`outputs = activation(inputs.kernel + bias)`
Where `activation` is the activation function passed as the `activation`
argument (if not `None`), `kernel` is a weights matrix created by the layer,
and `bias` is a bias vector created by the layer
(only if `use_bias` is `True`).
Note: if the input to the layer has a rank greater than 2, then it is
flattened prior to the initial matrix multiply by `kernel`.
Arguments:
units: Integer or Long, dimensionality of the output space.
activation: Activation function (callable). Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: Initializer function for the weight matrix.
bias_initializer: Initializer function for the bias.
kernel_regularizer: Regularizer function for the weight matrix.
bias_regularizer: Regularizer function for the bias.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such cases.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (callable).
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: Initializer instance (or name) for the weight matrix.
bias_initializer: Initializer instance (or name) for the bias.
kernel_regularizer: Regularizer instance for the weight matrix (callable)
bias_regularizer: Regularizer instance for the bias (callable).
activity_regularizer: Regularizer instance for the output (callable)
kernel: Weight matrix (TensorFlow variable or tensor).
bias: Bias vector, if applicable (TensorFlow variable or tensor).
"""
def __init__(self,
units,
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
super(MaskedFullyConnected, self).__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
**kwargs)
self.units = units
self.activation = activation
self.use_bias = use_bias
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.kernel_regularizer = kernel_regularizer
self.bias_regularizer = bias_regularizer
self.input_spec = base.InputSpec(min_ndim=2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if input_shape[-1].value is None:
raise ValueError('The last dimension of the inputs to `Dense` '
'should be defined. Found `None`.')
self.input_spec = base.InputSpec(
min_ndim=2, axes={-1: input_shape[-1].value})
self.kernel = self.add_variable(
'kernel',
shape=[input_shape[-1].value, self.units],
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
dtype=self.dtype,
trainable=True)
self.mask = self.add_variable(
name='mask',
shape=[input_shape[-1].value, self.units],
initializer=init_ops.ones_initializer(),
trainable=False,
dtype=self.dtype)
self.threshold = self.add_variable(
name='threshold',
shape=[],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=self.dtype)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
MASKED_WEIGHT_NAME)
ops.add_to_collection(MASK_COLLECTION, self.mask)
ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)
if self.use_bias:
self.bias = self.add_variable(
'bias',
shape=[
self.units,
],
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
self.built = True
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
shape = inputs.get_shape().as_list()
output_shape = shape[:-1] + [self.units]
if len(output_shape) > 2:
# Broadcasting is required for the inputs.
outputs = standard_ops.tensordot(inputs, self.masked_kernel,
[[len(shape) - 1], [0]])
# Reshape the output back to the original ndim of the input.
outputs.set_shape(output_shape)
else:
outputs = standard_ops.matmul(inputs, self.masked_kernel)
if self.use_bias:
outputs = nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs) # pylint: disable=not-callable
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
input_shape = input_shape.with_rank_at_least(2)
if input_shape[-1].value is None:
raise ValueError(
'The innermost dimension of input_shape must be defined, but saw: %s'
% input_shape)
return input_shape[:-1].concatenate(self.units)

View File

@ -0,0 +1,364 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow layers with added variables for parameter masking.
Branched from tensorflow/contrib/layers/python/layers/layers.py
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
def _model_variable_getter(getter,
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
rename=None,
use_resource=None,
**_):
"""Getter that uses model_variable for compatibility with core layers."""
short_name = name.split('/')[-1]
if rename and short_name in rename:
name_components = name.split('/')
name_components[-1] = rename[short_name]
name = '/'.join(name_components)
return variables.model_variable(
name,
shape=shape,
dtype=dtype,
initializer=initializer,
regularizer=regularizer,
collections=collections,
trainable=trainable,
caching_device=caching_device,
partitioner=partitioner,
custom_getter=getter,
use_resource=use_resource)
def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""
# VariableScope will nest the getters
def layer_variable_getter(getter, *args, **kwargs):
kwargs['rename'] = rename
return _model_variable_getter(getter, *args, **kwargs)
return layer_variable_getter
def _add_variable_to_collections(variable, collections_set, collections_name):
"""Adds variable (or all its parts) to all collections with that name."""
collections = utils.get_variable_collections(collections_set,
collections_name) or []
variables_list = [variable]
if isinstance(variable, tf_variables.PartitionedVariable):
variables_list = [v for v in variable]
for collection in collections:
for var in variables_list:
if var not in ops.get_collection(collection):
ops.add_to_collection(collection, var)
@add_arg_scope
def masked_convolution(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
data_format=None,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds an 2D convolution followed by an optional batch_norm layer.
The layer creates a mask variable on top of the weight variable. The input to
the convolution operation is the elementwise multiplication of the mask
variable and the weigh
It is required that 1 <= N <= 3.
`convolution` creates a variable called `weights`, representing the
convolutional kernel, that is convolved (actually cross-correlated) with the
`inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
provided (such as `batch_norm`), it is then applied. Otherwise, if
`normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
variable would be created and added the activations. Finally, if
`activation_fn` is not `None`, it is applied to the activations as well.
Performs atrous convolution with input stride/dilation rate equal to `rate`
if a value > 1 for any dimension of `rate` is specified. In this case
`stride` values != 1 are not supported.
Args:
inputs: A Tensor of rank N+2 of shape
`[batch_size] + input_spatial_shape + [in_channels]` if data_format does
not start with "NC" (default), or
`[batch_size, in_channels] + input_spatial_shape` if data_format starts
with "NC".
num_outputs: Integer, the number of output filters.
kernel_size: A sequence of N positive integers specifying the spatial
dimensions of of the filters. Can be a single integer to specify the same
value for all spatial dimensions.
stride: A sequence of N positive integers specifying the stride at which to
compute output. Can be a single integer to specify the same value for all
spatial dimensions. Specifying any `stride` value != 1 is incompatible
with specifying any `rate` value != 1.
padding: One of `"VALID"` or `"SAME"`.
data_format: A string or None. Specifies whether the channel dimension of
the `input` and output is the last dimension (default, or if `data_format`
does not start with "NC"), or the second dimension (if `data_format`
starts with "NC"). For N=1, the valid values are "NWC" (default) and
"NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
For N=3, the valid values are "NDHWC" (default) and "NCDHW".
rate: A sequence of N positive integers specifying the dilation rate to use
for atrous convolution. Can be a single integer to specify the same
value for all spatial dimensions. Specifying any `rate` value != 1 is
incompatible with specifying any `stride` value != 1.
activation_fn: Activation function. The default value is a ReLU function.
Explicitly set it to None to skip it and maintain a linear activation.
normalizer_fn: Normalization function to use instead of `biases`. If
`normalizer_fn` is provided then `biases_initializer` and
`biases_regularizer` are ignored and `biases` are not created nor added.
default set to None for no normalizer function
normalizer_params: Normalization function parameters.
weights_initializer: An initializer for the weights.
weights_regularizer: Optional regularizer for the weights.
biases_initializer: An initializer for the biases. If None skip biases.
biases_regularizer: Optional regularizer for the biases.
reuse: Whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
variables_collections: Optional list of collections for all the variables or
a dictionary containing a different list of collection per variable.
outputs_collections: Collection to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for `variable_scope`.
Returns:
A tensor representing the output of the operation.
Raises:
ValueError: If `data_format` is invalid.
ValueError: Both 'rate' and `stride` are not uniformly 1.
"""
if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
raise ValueError('Invalid data_format: %r' % (data_format,))
layer_variable_getter = _build_variable_getter({
'bias': 'biases',
'kernel': 'weights'
})
with variable_scope.variable_scope(
scope, 'Conv', [inputs], reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
input_rank = inputs.get_shape().ndims
if input_rank == 3:
raise ValueError('Sparse Convolution not supported for input with rank',
input_rank)
elif input_rank == 4:
layer_class = core.MaskedConv2D
elif input_rank == 5:
raise ValueError('Sparse Convolution not supported for input with rank',
input_rank)
else:
raise ValueError('Sparse Convolution not supported for input with rank',
input_rank)
if data_format is None or data_format == 'NHWC':
df = 'channels_last'
elif data_format == 'NCHW':
df = 'channels_first'
else:
raise ValueError('Unsupported data fromat', data_format)
layer = layer_class(
filters=num_outputs,
kernel_size=kernel_size,
strides=stride,
padding=padding,
data_format=df,
dilation_rate=rate,
activation=None,
use_bias=not normalizer_fn and biases_initializer,
kernel_initializer=weights_initializer,
bias_initializer=biases_initializer,
kernel_regularizer=weights_regularizer,
bias_regularizer=biases_regularizer,
activity_regularizer=None,
trainable=trainable,
name=sc.name,
dtype=inputs.dtype.base_dtype,
_scope=sc,
_reuse=reuse)
outputs = layer.apply(inputs)
# Add variables to collections.
_add_variable_to_collections(layer.kernel, variables_collections, 'weights')
if layer.use_bias:
_add_variable_to_collections(layer.bias, variables_collections, 'biases')
if normalizer_fn is not None:
normalizer_params = normalizer_params or {}
outputs = normalizer_fn(outputs, **normalizer_params)
if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections,
sc.original_name_scope, outputs)
masked_conv2d = masked_convolution
@add_arg_scope
def masked_fully_connected(
inputs,
num_outputs,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a sparse fully connected layer. The weight matrix is masked.
`fully_connected` creates a variable called `weights`, representing a fully
connected weight matrix, which is multiplied by the `inputs` to produce a
`Tensor` of hidden units. If a `normalizer_fn` is provided (such as
`batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
None and a `biases_initializer` is provided then a `biases` variable would be
created and added the hidden units. Finally, if `activation_fn` is not `None`,
it is applied to the hidden units as well.
Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened
prior to the initial matrix multiply by `weights`.
Args:
inputs: A tensor of at least rank 2 and static value for the last dimension;
i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
num_outputs: Integer or long, the number of output units in the layer.
activation_fn: Activation function. The default value is a ReLU function.
Explicitly set it to None to skip it and maintain a linear activation.
normalizer_fn: Normalization function to use instead of `biases`. If
`normalizer_fn` is provided then `biases_initializer` and
`biases_regularizer` are ignored and `biases` are not created nor added.
default set to None for no normalizer function
normalizer_params: Normalization function parameters.
weights_initializer: An initializer for the weights.
weights_regularizer: Optional regularizer for the weights.
biases_initializer: An initializer for the biases. If None skip biases.
biases_regularizer: Optional regularizer for the biases.
reuse: Whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
variables_collections: Optional list of collections for all the variables or
a dictionary containing a different list of collections per variable.
outputs_collections: Collection to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for variable_scope.
Returns:
The tensor variable representing the result of the series of operations.
Raises:
ValueError: If x has rank less than 2 or if its last dimension is not set.
"""
if not isinstance(num_outputs, six.integer_types):
raise ValueError('num_outputs should be int or long, got %s.' %
(num_outputs,))
layer_variable_getter = _build_variable_getter({
'bias': 'biases',
'kernel': 'weights'
})
with variable_scope.variable_scope(
scope,
'fully_connected', [inputs],
reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
layer = core.MaskedFullyConnected(
units=num_outputs,
activation=None,
use_bias=not normalizer_fn and biases_initializer,
kernel_initializer=weights_initializer,
bias_initializer=biases_initializer,
kernel_regularizer=weights_regularizer,
bias_regularizer=biases_regularizer,
activity_regularizer=None,
trainable=trainable,
name=sc.name,
dtype=inputs.dtype.base_dtype,
_scope=sc,
_reuse=reuse)
outputs = layer.apply(inputs)
# Add variables to collections.
_add_variable_to_collections(layer.kernel, variables_collections, 'weights')
if layer.bias is not None:
_add_variable_to_collections(layer.bias, variables_collections, 'biases')
# Apply normalizer function / layer.
if normalizer_fn is not None:
if not normalizer_params:
normalizer_params = {}
outputs = normalizer_fn(outputs, **normalizer_params)
if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections,
sc.original_name_scope, outputs)

View File

@ -0,0 +1,139 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for imagingvision.intelligence.tensorflow.model_pruning.layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class MaskedConvolutionLayerTest(test.TestCase):
def setUp(self):
super(MaskedConvolutionLayerTest, self).setUp()
self.height, self.width = 7, 9
def testInvalidRank3(self):
input_tensor = array_ops.ones((self.height, self.width, 3))
with self.assertRaisesRegexp(ValueError, 'rank'):
layers.masked_conv2d(input_tensor, 32, 3)
def testInvalidRank5(self):
input_tensor = array_ops.ones((8, 8, self.height, self.width, 3))
with self.assertRaisesRegexp(ValueError, 'rank'):
layers.masked_conv2d(input_tensor, 32, 3)
def testSingleConvMaskAdded(self):
kernel_size = 3
input_depth, output_depth = 8, 32
input_tensor = array_ops.ones((8, self.height, self.width, input_depth))
layers.masked_conv2d(input_tensor, output_depth, kernel_size)
masks = ops.get_collection(core_layers.MASK_COLLECTION)
self.assertEqual(len(masks), 1)
self.assertListEqual(masks[0].get_shape().as_list(),
[kernel_size, kernel_size, input_depth, output_depth])
masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
self.assertEqual(len(masked_weight), 1)
self.assertListEqual(masked_weight[0].get_shape().as_list(),
[kernel_size, kernel_size, input_depth, output_depth])
def testMultipleConvMaskAdded(self):
number_of_layers = 5
kernel_size = 3
base_depth = 4
depth_step = 7
input_tensor = array_ops.ones((8, self.height, self.width, base_depth))
top_layer = input_tensor
for ix in range(number_of_layers):
top_layer = layers.masked_conv2d(top_layer, base_depth +
(ix + 1) * depth_step, kernel_size)
masks = ops.get_collection(core_layers.MASK_COLLECTION)
self.assertEqual(len(masks), number_of_layers)
for ix in range(number_of_layers):
self.assertListEqual(masks[ix].get_shape().as_list(), [
kernel_size, kernel_size, base_depth + ix * depth_step,
base_depth + (ix + 1) * depth_step
])
masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
self.assertEqual(len(masked_weight), number_of_layers)
for ix in range(number_of_layers):
self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
kernel_size, kernel_size, base_depth + ix * depth_step,
base_depth + (ix + 1) * depth_step
])
class MaskedFullyConnectedLayerTest(test.TestCase):
def testSingleFCMaskAdded(self):
input_depth, output_depth = 8, 32
input_tensor = array_ops.ones((5, input_depth))
layers.masked_fully_connected(input_tensor, output_depth)
masks = ops.get_collection(core_layers.MASK_COLLECTION)
self.assertEqual(len(masks), 1)
self.assertListEqual(masks[0].get_shape().as_list(),
[input_depth, output_depth])
masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
self.assertEqual(len(masked_weight), 1)
self.assertListEqual(masked_weight[0].get_shape().as_list(),
[input_depth, output_depth])
def testMultipleConvMaskAdded(self):
number_of_layers = 5
base_depth = 4
depth_step = 7
input_tensor = array_ops.ones((8, base_depth))
top_layer = input_tensor
for ix in range(number_of_layers):
top_layer = layers.masked_fully_connected(top_layer, base_depth +
(ix + 1) * depth_step)
masks = ops.get_collection(core_layers.MASK_COLLECTION)
self.assertEqual(len(masks), number_of_layers)
for ix in range(number_of_layers):
self.assertListEqual(masks[ix].get_shape().as_list(), [
base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
])
masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
self.assertEqual(len(masked_weight), number_of_layers)
for ix in range(number_of_layers):
self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
])
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,340 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN Cells with pruning.
This module implements BasicLSTMCell and LSTMCell with pruning.
Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell as tf_rnn
class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
"""Basic LSTM recurrent network cell with pruning.
Overrides the call method of tensorflow BasicLSTMCell and injects the weight
masks
The implementation is based on: http://arxiv.org/abs/1409.2329.
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
that follows.
"""
def __init__(self,
num_units,
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None,
name=None):
"""Initialize the basic LSTM cell with pruning.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must set to `0.0` manually when restoring from CudnnLSTM-trained
checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
activation: Activation function of the inner states. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
When restoring from CudnnLSTM-trained checkpoints, must use
CudnnCompatibleLSTMCell instead.
"""
super(MaskedBasicLSTMCell, self).__init__(
num_units,
forget_bias=forget_bias,
state_is_tuple=state_is_tuple,
activation=activation,
reuse=reuse,
name=name)
def build(self, inputs_shape):
# Call the build method of the parent class.
super(MaskedBasicLSTMCell, self).build(inputs_shape)
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
name="mask",
shape=[input_depth + h_depth, 4 * h_depth],
initializer=init_ops.ones_initializer(),
trainable=False,
dtype=self.dtype)
self._threshold = self.add_variable(
name="threshold",
shape=[],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=self.dtype)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
core_layers.MASKED_WEIGHT_NAME)
if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
self._masked_kernel)
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
def call(self, inputs, state):
"""Long short-term memory cell (LSTM) with masks for pruning.
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
`[batch_size, self.state_size]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
`[batch_size, 2 * self.state_size]`.
Returns:
A pair containing the new hidden state, and the new state (either a
`LSTMStateTuple` or a concatenated state, depending on
`state_is_tuple`).
"""
sigmoid = math_ops.sigmoid
one = constant_op.constant(1, dtype=dtypes.int32)
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, h], 1), self._masked_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=gate_inputs, num_or_size_splits=4, axis=one)
forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
# Note that using `add` and `multiply` instead of `+` and `*` gives a
# performance improvement. So using those at the cost of readability.
add = math_ops.add
multiply = math_ops.multiply
new_c = add(
multiply(c, sigmoid(add(f, forget_bias_tensor))),
multiply(sigmoid(i), self._activation(j)))
new_h = multiply(self._activation(new_c), sigmoid(o))
if self._state_is_tuple:
new_state = tf_rnn.LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
class MaskedLSTMCell(tf_rnn.LSTMCell):
"""LSTMCell with pruning.
Overrides the call method of tensorflow LSTMCell and injects the weight masks.
Masks are applied to only the weight matrix of the LSTM and not the
projection matrix.
"""
def __init__(self,
num_units,
use_peepholes=False,
cell_clip=None,
initializer=None,
num_proj=None,
proj_clip=None,
num_unit_shards=None,
num_proj_shards=None,
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None):
"""Initialize the parameters for an LSTM cell with masks for pruning.
Args:
num_units: int, The number of units in the LSTM cell
use_peepholes: bool, set True to enable diagonal/peephole connections.
cell_clip: (optional) A float value, if provided the cell state is clipped
by this value prior to the cell output activation.
initializer: (optional) The initializer to use for the weight and
projection matrices.
num_proj: (optional) int, The output dimensionality for the projection
matrices. If None, no projection is performed.
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
provided, then the projected values are clipped elementwise to within
`[-proj_clip, proj_clip]`.
num_unit_shards: Deprecated, will be removed by Jan. 2017.
Use a variable_scope partitioner instead.
num_proj_shards: Deprecated, will be removed by Jan. 2017.
Use a variable_scope partitioner instead.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training. Must set it manually to `0.0` when restoring from
CudnnLSTM trained checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.
activation: Activation function of the inner states. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
When restoring from CudnnLSTM-trained checkpoints, must use
CudnnCompatibleLSTMCell instead.
"""
super(MaskedLSTMCell, self).__init__(
num_units,
use_peepholes=use_peepholes,
cell_clip=cell_clip,
initializer=initializer,
num_proj=num_proj,
proj_clip=proj_clip,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
forget_bias=forget_bias,
state_is_tuple=state_is_tuple,
activation=activation,
reuse=reuse)
def build(self, inputs_shape):
# Call the build method of the parent class.
super(MaskedLSTMCell, self).build(inputs_shape)
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
name="mask",
shape=[input_depth + h_depth, 4 * h_depth],
initializer=init_ops.ones_initializer(),
trainable=False,
dtype=self.dtype)
self._threshold = self.add_variable(
name="threshold",
shape=[],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=self.dtype)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
core_layers.MASKED_WEIGHT_NAME)
if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
self._masked_kernel)
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
inputs: input Tensor, 2D, `[batch, num_units].
state: if `state_is_tuple` is False, this must be a state Tensor,
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
`m_state`.
Returns:
A tuple containing:
- A `2-D, [batch, output_dim]`, Tensor representing the output of the
LSTM after reading `inputs` when previous state was `state`.
Here output_dim is:
num_proj if num_proj was set,
num_units otherwise.
- Tensor(s) representing the new state of LSTM after reading `inputs` when
the previous state was `state`. Same type and shape(s) as `state`.
Raises:
ValueError: If input size cannot be inferred from inputs via
static shape inference.
"""
num_proj = self._num_units if self._num_proj is None else self._num_proj
sigmoid = math_ops.sigmoid
if self._state_is_tuple:
(c_prev, m_prev) = state
else:
c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = math_ops.matmul(
array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
# Diagonal connections
if self._use_peepholes:
c = (
sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
else:
c = (
sigmoid(f + self._forget_bias) * c_prev +
sigmoid(i) * self._activation(j))
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
else:
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
m = math_ops.matmul(m, self._proj_kernel)
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
new_state = (
tf_rnn.LSTMStateTuple(c, m)
if self._state_is_tuple else array_ops.concat([c, m], 1))
return m, new_state

View File

@ -0,0 +1,85 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for creating different number of masks in rnn_cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import rnn_cells
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell as tf_rnn_cells
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class RnnCellsTest(test.TestCase):
def setUp(self):
super(RnnCellsTest, self).setUp()
self.batch_size = 8
self.dim = 10
def testMaskedBasicLSTMCell(self):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
with self.test_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
h = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
state = tf_rnn_cells.LSTMStateTuple(c, h)
lstm_cell = rnn_cells.MaskedBasicLSTMCell(self.dim)
lstm_cell(inputs, state)
self.assertEqual(len(pruning.get_masks()), expected_num_masks)
self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
self.assertEqual(len(pruning.get_weights()), expected_num_masks)
for mask in pruning.get_masks():
self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
for weight in pruning.get_weights():
self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
def testMaskedLSTMCell(self):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
with self.test_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
h = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
state = tf_rnn_cells.LSTMStateTuple(c, h)
lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
lstm_cell(inputs, state)
self.assertEqual(len(pruning.get_masks()), expected_num_masks)
self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
self.assertEqual(len(pruning.get_weights()), expected_num_masks)
for mask in pruning.get_masks():
self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
for weight in pruning.get_weights():
self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,188 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper around tf-slim's training code contrib/slim/python/slim/learning.py
to support training of pruned models
*******************************************************************
* A simple working training script with support for model pruning *
*******************************************************************
# Load data and create the model:
images, labels = LoadData(...)
predictions = MyModel(images)
# Define the loss:
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()
# Define the optimizer:
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
# Set up sparsity
sparsity = pruning.setup_gradual_sparsity(self.global_step)
# Create mask update op
mask_update_op = pruning.add_mask_update_ip(sparsity)
# Run training.
learning.train(train_op,
my_log_dir,
mask_update_op)
see contrib/slim/python/slim/learning.py for additional examples
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib import slim as _slim
_USE_DEFAULT = 0
train_step = _slim.learning.train_step
def train(train_op,
logdir,
mask_update_op,
train_step_fn=train_step,
train_step_kwargs=_USE_DEFAULT,
log_every_n_steps=1,
graph=None,
master='',
is_chief=True,
global_step=None,
number_of_steps=None,
init_op=_USE_DEFAULT,
init_feed_dict=None,
local_init_op=_USE_DEFAULT,
init_fn=None,
ready_op=_USE_DEFAULT,
summary_op=_USE_DEFAULT,
save_summaries_secs=600,
summary_writer=_USE_DEFAULT,
startup_delay_steps=0,
saver=None,
save_interval_secs=600,
sync_optimizer=None,
session_config=None,
trace_every_n_steps=None):
"""Wrapper around tf-slim's train function.
Runs a training loop using a TensorFlow supervisor.
When the sync_optimizer is supplied, gradient updates are applied
synchronously. Otherwise, gradient updates are applied asynchronous.
Args:
train_op: A `Tensor` that, when executed, will apply the gradients and
return the loss value.
logdir: The directory where training logs are written to. If None, model
checkpoints and summaries will not be written.
mask_update_op: Operation that upon execution updates the weight masks and
thresholds.
train_step_fn: The function to call in order to execute a single gradient
step. The function must have take exactly four arguments: the current
session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
default, two `Boolean`, scalar ops called "should_stop" and "should_log"
are provided.
log_every_n_steps: The frequency, in terms of global steps, that the loss
and global step and logged.
graph: The graph to pass to the supervisor. If no graph is supplied the
default graph is used.
master: The address of the tensorflow master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
global_step: The `Tensor` representing the global step. If left as `None`,
then slim.variables.get_or_create_global_step() is used.
number_of_steps: The max number of gradient steps to take during training,
as measured by 'global_step': training will stop if global_step is
greater than 'number_of_steps'. If the value is left as None, training
proceeds indefinitely.
init_op: The initialization operation. If left to its default value, then
the session is initialized by calling `tf.global_variables_initializer()`.
init_feed_dict: A feed dictionary to use when executing the `init_op`.
local_init_op: The local initialization operation. If left to its default
value, then the session is initialized by calling
`tf.local_variables_initializer()` and `tf.tables_initializer()`.
init_fn: An optional callable to be executed after `init_op` is called. The
callable must accept one argument, the session being initialized.
ready_op: Operation to check if the model is ready to use. If left to its
default value, then the session checks for readiness by calling
`tf.report_uninitialized_variables()`.
summary_op: The summary operation.
save_summaries_secs: How often, in seconds, to save summaries.
summary_writer: `SummaryWriter` to use. Can be `None`
to indicate that no summaries should be written. If unset, we
create a SummaryWriter.
startup_delay_steps: The number of steps to wait for before beginning. Note
that this must be 0 if a sync_optimizer is supplied.
saver: Saver to save checkpoints. If None, a default one will be created
and used.
save_interval_secs: How often, in seconds, to save the model to `logdir`.
sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
them. If the argument is supplied, gradient updates will be synchronous.
If left as `None`, gradient updates will be asynchronous.
session_config: An instance of `tf.ConfigProto` that will be used to
configure the `Session`. If left as `None`, the default will be used.
trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
and add it to the summaries every `trace_every_n_steps`. If None, no trace
information will be produced or saved.
Returns:
the value of the loss function after training.
Raises:
ValueError: if `train_op` is empty or if `startup_delay_steps` is
non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
provided.
"""
def train_step_with_pruning_fn(sess, train_op, global_step,
train_step_kwargs):
total_loss, should_stop = train_step_fn(sess, train_op, global_step,
train_step_kwargs)
sess.run(mask_update_op)
return total_loss, should_stop
total_loss, _ = _slim.learning.train(
train_op,
logdir,
train_step_fn=train_step_with_pruning_fn,
train_step_kwargs=train_step_kwargs,
log_every_n_steps=log_every_n_steps,
graph=graph,
master=master,
is_chief=is_chief,
global_step=global_step,
number_of_steps=number_of_steps,
init_op=init_op,
init_feed_dict=init_feed_dict,
local_init_op=local_init_op,
init_fn=init_fn,
ready_op=ready_op,
summary_op=summary_op,
save_summaries_secs=save_summaries_secs,
summary_writer=summary_writer,
startup_delay_steps=startup_delay_steps,
saver=saver,
save_interval_secs=save_interval_secs,
sync_optimizer=sync_optimizer,
session_config=session_config,
trace_every_n_steps=trace_every_n_steps)
return total_loss

View File

@ -0,0 +1,585 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to add support for magnitude-based model pruning.
# Adds variables and ops to the graph to enable
# elementwise masking of weights
apply_mask(weights)
# Returns a list containing the sparsity of each of the weight tensors
get_weight_sparsity()
# Returns a list of all the masked weight tensorflow variables
get_masked_weights()
# Returns a list of all the mask tensorflow variables
get_masks()
# Returns a list of all the thresholds
get_thresholds()
# Returns a list of all the weight tensors that have been masked
get_weights()
The Pruning class uses a proto (defined in pruning.proto) to set up the
parameters for a pruning specification. Here's a typical usage:
# Initialize a pruning spec from a proto
pruning_spec = '/tmp/pruning.pb'
p = Pruning(pruning_spec)
# Add mask update ops to the graph
mask_update_op = p.conditional_mask_update_op()
# Add the summaries
p.add_pruning_summaries()
# Run the op
session.run(mask_update_op)
# An object of the pruning also accepts externally defined sparsity:
sparsity = tf.Variable(0.5, name = "ConstantSparsity")
pruning_spec = '/tmp/pruning.pb'
p = Pruning(pruning_spec, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
_MASK_COLLECTION = core.MASK_COLLECTION
_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME
def _weight_mask_variable(var, scope):
"""Create a mask for the weights.
This function adds a variable 'mask' to the graph.
Args:
var: the weight variable that needs to be masked
scope: The variable scope of the variable var
Returns:
the mask variable of the same size and shape as var, initialized to all 1s.
"""
with variable_scope.variable_scope(scope):
mask = variable_scope.get_variable(
'mask',
var.get_shape(),
initializer=init_ops.ones_initializer(),
trainable=False,
dtype=var.dtype)
return mask
def _weight_threshold_variable(var, scope):
"""Create a scalar threshold for the weights.
This function adds a variable
'threshold' to the graph.
Args:
var: The weight variable that needs to be masked
scope: The variable scope of the variable var
Returns:
a scalar threshold variable initialized to 0.
"""
with variable_scope.variable_scope(scope):
threshold = variable_scope.get_variable(
'threshold', [],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=var.dtype)
return threshold
def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
"""Return histogram of values.
Given the tensor `values`, this operation returns a rank 1 histogram counting
the number of entries in `values` that fell into every bin. The bins are
equal width and determined by the arguments `value_range` and `nbins`.
Args:
values: Numeric `Tensor`.
value_range: Shape [2] `Tensor` of same `dtype` as `values`.
values <= value_range[0] will be mapped to hist[0],
values >= value_range[1] will be mapped to hist[-1].
nbins: Scalar `int32 Tensor`. Number of histogram bins.
dtype: dtype for returned histogram.
name: A name for this operation (defaults to 'histogram').
Returns:
A 1-D `Tensor` holding histogram of values.
"""
with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
values = ops.convert_to_tensor(values, name='values')
values = gen_array_ops.reshape(values, [-1])
value_range = ops.convert_to_tensor(value_range, name='value_range')
nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins')
nbins_float = math_ops.cast(nbins, values.dtype)
# Map tensor values that fall within value_range to [0, 1].
scaled_values = math_ops.truediv(
values - value_range[0],
value_range[1] - value_range[0],
name='scaled_values')
# map tensor values within the open interval value_range to {0,.., nbins-1},
# values outside the open interval will be zero or less, or nbins or more.
indices = math_ops.floor(nbins_float * scaled_values, name='indices')
# Clip edge cases (e.g. value = value_range[1]) or "outliers."
indices = math_ops.cast(
clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32)
return math_ops.unsorted_segment_sum(
array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
def _determine_partitioned_axis(partitioned_variable):
partitioned_axis = 0
concatenated_variable_shape = partitioned_variable.get_shape()
for partition in partitioned_variable:
partition_shape = partition.get_shape()
maybe_partitioned_axis = np.less(partition_shape,
concatenated_variable_shape)
# Sanity check: make sure number of partitioned axis == 1
if np.count_nonzero(maybe_partitioned_axis) != 1:
raise ValueError('Number of partitioned axes %s not equal to 1' %
np.count_nonzero(maybe_partitioned_axis))
partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
return partitioned_axis
def _variable_assign(var, new_value):
return state_ops.assign(var, new_value, name=var.op.name + '_assign')
def _partitioned_variable_assign(partitioned_var, new_value):
"""Assign op for partitioned variables.
Args:
partitioned_var: A partitioned tensotflow variable
new_value: Value to be assigned to the variable var
Returns:
A tensorflow op that groups the assign ops for each of the variable slices
"""
# Determine which axis was used to partition the variable. Currently
# tensorflow allows partitioning variable only along 1 axis.
axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis(
partitioned_var)
partition_sizes = np.array(
[partition.get_shape()[axis] for partition in partitioned_var])
new_partitioned_values = array_ops.split(
new_value,
ops.convert_to_tensor(partition_sizes, dtype=np.int32),
axis=axis)
op_list = []
for partition in partitioned_var:
op_list.append(
_variable_assign(partition, new_partitioned_values[len(op_list)]))
return control_flow_ops.group(
*op_list, name=partitioned_var.name + '_group_assign')
def apply_mask(x, scope=''):
"""Apply mask to a given weight tensor.
Args:
x: Input weight tensor
scope: The current variable scope. Defaults to ""
Returns:
Tensor representing masked_weights
"""
mask = _weight_mask_variable(x, scope)
threshold = _weight_threshold_variable(x, scope)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)
# Make sure the mask for a given variable are not added multiple times to the
# collection. This is particularly important when applying mask to RNN's
# weight variables
if mask not in ops.get_collection_ref(_MASK_COLLECTION):
ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
ops.add_to_collection(_MASK_COLLECTION, mask)
ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
ops.add_to_collection(_WEIGHT_COLLECTION, x)
return masked_weights
def get_masked_weights():
return ops.get_collection(_MASKED_WEIGHT_COLLECTION)
def get_masks():
return ops.get_collection(_MASK_COLLECTION)
def get_thresholds():
return ops.get_collection(_THRESHOLD_COLLECTION)
def get_weights():
return ops.get_collection(_WEIGHT_COLLECTION)
def get_weight_sparsity():
"""Get sparsity of the weights.
Args:
None
Returns:
A list containing the sparsity of each of the weight tensors
"""
masks = get_masks()
return [nn_impl.zero_fraction(mask) for mask in masks]
def get_pruning_hparams():
"""Get a tf.HParams object with the default values for the hyperparameters.
name: string
name of the pruning specification. Used for adding summaries and ops under
a common tensorflow name_scope
begin_pruning_step: integer
the global step at which to begin pruning
end_pruning_step: integer
the global step at which to terminate pruning. Defaults to -1 implying
that pruning continues till the training stops
do_not_prune: list of strings
list of layers that are not pruned
threshold_decay: float
the decay factor to use for exponential decay of the thresholds
pruning_frequency: integer
How often should the masks be updated? (in # of global_steps)
nbins: integer
number of bins to use for histogram computation
initial_sparsity: float
initial sparsity value
target_sparsity: float
target sparsity value
sparsity_function_begin_step: integer
the global step at this which the gradual sparsity function begins to
take effect
sparsity_function_end_step: integer
the global step used as the end point for the gradual sparsity function
sparsity_function_exponent: float
exponent = 1 is linearly varying sparsity between initial and final.
exponent > 1 varies more slowly towards the end than the beginning
We use the following sparsity function:
num_steps = (sparsity_function_end_step -
sparsity_function_begin_step)/pruning_frequency
sparsity(step) = (initial_sparsity - target_sparsity)*
[1-step/(num_steps -1)]**exponent + target_sparsity
Args:
None
Returns:
tf.HParams object initialized to default values
"""
return hparam.HParams(
name='model_pruning',
begin_pruning_step=0,
end_pruning_step=-1,
do_not_prune=[''],
threshold_decay=0.9,
pruning_frequency=10,
nbins=255,
initial_sparsity=0,
target_sparsity=0.5,
sparsity_function_begin_step=0,
sparsity_function_end_step=100,
sparsity_function_exponent=3)
class Pruning(object):
def __init__(self,
spec=None,
global_step=None,
sparsity=None,
partitioner=None):
"""Set up the specification for model pruning.
If a spec is provided, the sparsity is set up based on the sparsity_function
in the spec. The effect of sparsity_function is overridden if the sparsity
variable is passed to the constructor. This enables setting up arbitrary
sparsity profiles externally and passing it to this pruning functions.
Args:
spec: Pruning spec as defined in pruning.proto
global_step: A tensorflow variable that is used while setting up the
sparsity function
sparsity: A tensorflow scalar variable storing the sparsity
partitioner: The tensorflow partitioner function used to distribute
parameters across shards
"""
# Pruning specification
self._spec = spec if spec else get_pruning_hparams()
# A tensorflow variable that tracks the sparsity function.
# If not provided as input, the graph must already contain the global_step
# variable before calling this constructor.
self._global_step = self._setup_global_step(global_step)
# Stores the tensorflow sparsity variable.
# Built using self._setup_sparsity() or provided externally
self._sparsity = sparsity if sparsity else self._setup_sparsity()
# Stores the partitioner function uses to partition variables across tasks/
self._partitioner = partitioner
# List of tensorflow assignments ops for new masks and thresholds
self._assign_ops = []
# Tensorflow variable keeping track of the last global step when the masks
# were updated
self._last_update_step = self._setup_last_update_step()
def _setup_global_step(self, global_step):
graph_global_step = global_step
if graph_global_step is None:
graph_global_step = training_util.get_global_step()
return math_ops.cast(graph_global_step, np.int32)
def _setup_sparsity(self):
begin_step = self._spec.sparsity_function_begin_step
end_step = self._spec.sparsity_function_end_step
initial_sparsity = self._spec.initial_sparsity
target_sparsity = self._spec.target_sparsity
exponent = self._spec.sparsity_function_exponent
if begin_step >= end_step:
raise ValueError(
'Pruning must begin before it can end. begin_step=%d, end_step=%d' %
(begin_step, end_step))
with ops.name_scope(self._spec.name):
p = math_ops.minimum(1.0,
math_ops.maximum(
0.0,
math_ops.div(
math_ops.cast(self._global_step - begin_step,
np.float32),
end_step - begin_step)))
sparsity = math_ops.add(
math_ops.multiply(initial_sparsity - target_sparsity,
math_ops.pow(1 - p, exponent)),
target_sparsity,
name='sparsity')
return sparsity
def _setup_last_update_step(self):
with variable_scope.variable_scope(self._spec.name) as scope:
try:
last_update_step = variable_scope.get_variable(
'last_mask_update_step', [],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=np.int32)
except ValueError:
scope.reuse_variables()
last_update_step = variable_scope.get_variable(
'last_mask_update_step', dtype=np.int32)
return last_update_step
def _exists_in_do_not_prune_list(self, tensor_name):
do_not_prune_list = self._spec.do_not_prune
if not do_not_prune_list[0]:
return False
for layer_name in do_not_prune_list:
if tensor_name.find(layer_name) != -1:
return True
return False
def _update_mask(self, weights, threshold):
"""Updates the mask for a given weight tensor.
This functions first computes the cdf of the weight tensor, and estimates
the threshold value such that 'desired_sparsity' fraction of weights
have magnitude less than the threshold.
Args:
weights: The weight tensor that needs to be masked.
threshold: The current threshold value. The function will compute a new
threshold and return the exponential moving average using the current
value of threshold
Returns:
new_threshold: The new value of the threshold based on weights, and
desired_sparsity
new_mask: A n-D numpy array containing 0 or 1 to indicate which of the
values in weights falls below the threshold
Raises:
ValueError: if sparsity is not defined
"""
if self._sparsity is None:
raise ValueError('Sparsity variable undefined')
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(weights)
max_value = math_ops.reduce_max(abs_weights)
histogram = _histogram(
abs_weights, [0.0, max_value],
nbins=self._spec.nbins,
dtype=np.float32)
cdf = math_ops.cumsum(histogram)
norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram))
current_threshold = math_ops.multiply(
math_ops.div(
math_ops.reduce_sum(
math_ops.cast(
math_ops.less(norm_cdf, self._sparsity), np.float32)),
float(self._spec.nbins)), max_value)
smoothed_threshold = math_ops.add_n([
math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
math_ops.multiply(threshold, self._spec.threshold_decay)
])
new_mask = math_ops.cast(
math_ops.greater(abs_weights, smoothed_threshold), np.float32)
return smoothed_threshold, new_mask
def _get_mask_assign_ops(self):
# Make sure the assignment ops have not already been added to the list
if self._assign_ops:
raise ValueError(
'Assign op list not empty. _get_mask_assign_ops() called twice?')
masks = get_masks()
weights = get_weights()
thresholds = get_thresholds()
if len(masks) != len(thresholds):
raise ValueError(
'Number of masks %s and number of thresholds %s mismatch' %
(len(masks), len(thresholds)))
for index, mask in enumerate(masks):
threshold = thresholds[index]
weight = weights[index] if self._partitioner is None else weights[
index].as_tensor()
if self._spec.do_not_prune:
if self._exists_in_do_not_prune_list(mask.name):
continue
new_threshold, new_mask = self._update_mask(weight, threshold)
self._assign_ops.append(_variable_assign(threshold, new_threshold))
self._assign_ops.append(
_variable_assign(mask, new_mask) if self._partitioner is None else
_partitioned_variable_assign(mask, new_mask))
def mask_update_op(self):
with ops.name_scope(self._spec.name):
if not self._assign_ops:
self._get_mask_assign_ops()
with ops.control_dependencies([
state_ops.assign(
self._last_update_step,
self._global_step,
name='last_mask_update_step_assign')
]):
with ops.control_dependencies(self._assign_ops):
logging.info('Updating masks.')
return control_flow_ops.no_op('mask_update')
def conditional_mask_update_op(self):
def maybe_update_masks():
with ops.name_scope(self._spec.name):
is_step_within_pruning_range = math_ops.logical_and(
math_ops.greater_equal(self._global_step,
self._spec.begin_pruning_step),
# If end_pruning_step is negative, keep pruning forever!
math_ops.logical_or(
math_ops.less_equal(self._global_step,
self._spec.end_pruning_step),
math_ops.less(self._spec.end_pruning_step, 0)))
is_pruning_step = math_ops.less_equal(
math_ops.add(self._last_update_step, self._spec.pruning_frequency),
self._global_step)
return math_ops.logical_and(is_step_within_pruning_range,
is_pruning_step)
def mask_update_op():
return self.mask_update_op()
def no_update_op():
return control_flow_ops.no_op()
return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
no_update_op)
def add_pruning_summaries(self):
"""Adds summaries for this pruning spec.
Args: none
Returns: none
"""
with ops.name_scope(self._spec.name + '_summaries'):
summary.scalar('sparsity', self._sparsity)
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
for index, mask in enumerate(masks):
if not self._exists_in_do_not_prune_list(mask.name):
summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
summary.scalar(thresholds[index].op.name + '/threshold',
thresholds[index])
def print_hparams(self):
logging.info(self._spec.to_json())

View File

@ -0,0 +1,162 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in pruning library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
class PruningHParamsTest(test.TestCase):
PARAM_LIST = [
"name=test", "threshold_decay=0.9", "pruning_frequency=10",
"do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
"target_sparsity=0.9"
]
TEST_HPARAMS = ",".join(PARAM_LIST)
def setUp(self):
super(PruningHParamsTest, self).setUp()
# Add global step variable to the graph
self.global_step = training_util.get_or_create_global_step()
# Add sparsity
self.sparsity = variables.Variable(0.5, name="sparsity")
# Parse hparams
self.pruning_hparams = pruning.get_pruning_hparams().parse(
self.TEST_HPARAMS)
def testInit(self):
p = pruning.Pruning(self.pruning_hparams)
self.assertEqual(p._spec.name, "test")
self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
self.assertEqual(p._spec.pruning_frequency, 10)
self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"])
self.assertEqual(p._spec.sparsity_function_end_step, 100)
self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
def testInitWithExternalSparsity(self):
with self.test_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
variables.global_variables_initializer().run()
sparsity = p._sparsity.eval()
self.assertAlmostEqual(sparsity, 0.5)
def testInitWithVariableReuse(self):
with self.test_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
p_copy = pruning.Pruning(
spec=self.pruning_hparams, sparsity=self.sparsity)
variables.global_variables_initializer().run()
sparsity = p._sparsity.eval()
self.assertAlmostEqual(sparsity, 0.5)
self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())
class PruningTest(test.TestCase):
def setUp(self):
super(PruningTest, self).setUp()
self.global_step = training_util.get_or_create_global_step()
def testCreateMask2D(self):
width = 10
height = 20
with self.test_session():
weights = variables.Variable(
random_ops.random_normal([width, height], stddev=1), name="weights")
masked_weights = pruning.apply_mask(weights,
variable_scope.get_variable_scope())
variables.global_variables_initializer().run()
weights_val = weights.eval()
masked_weights_val = masked_weights.eval()
self.assertAllEqual(weights_val, masked_weights_val)
def testUpdateSingleMask(self):
with self.test_session() as session:
weights = variables.Variable(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
sparsity = variables.Variable(0.5, name="sparsity")
p = pruning.Pruning(sparsity=sparsity)
p._spec.threshold_decay = 0.0
mask_update_op = p.mask_update_op()
variables.global_variables_initializer().run()
masked_weights_val = masked_weights.eval()
self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
session.run(mask_update_op)
masked_weights_val = masked_weights.eval()
self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
def testPartitionedVariableMasking(self):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
with self.test_session() as session:
with variable_scope.variable_scope("", partitioner=partitioner):
sparsity = variables.Variable(0.5, name="Sparsity")
weights = variable_scope.get_variable(
"weights", initializer=math_ops.linspace(1.0, 100.0, 100))
masked_weights = pruning.apply_mask(
weights, scope=variable_scope.get_variable_scope())
p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
p._spec.threshold_decay = 0.0
mask_update_op = p.mask_update_op()
variables.global_variables_initializer().run()
masked_weights_val = masked_weights.eval()
session.run(mask_update_op)
masked_weights_val = masked_weights.eval()
self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
def testConditionalMaskUpdate(self):
param_list = [
"pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
]
test_spec = ",".join(param_list)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
weights = variables.Variable(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
sparsity = variables.Variable(0.00, name="sparsity")
# Set up pruning
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
p._spec.threshold_decay = 0.0
mask_update_op = p.conditional_mask_update_op()
sparsity_val = math_ops.linspace(0.0, 0.9, 10)
increment_global_step = state_ops.assign_add(self.global_step, 1)
non_zero_count = []
with self.test_session() as session:
variables.global_variables_initializer().run()
for i in range(10):
session.run(state_ops.assign(sparsity, sparsity_val[i]))
session.run(mask_update_op)
session.run(increment_global_step)
non_zero_count.append(np.count_nonzero(masked_weights.eval()))
# Weights pruned at steps 0,2,4,and,6
expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
self.assertAllEqual(expected_non_zero_count, non_zero_count)
if __name__ == "__main__":
test.main()