Moving model_pruning library to tf.contrib
PiperOrigin-RevId: 174214419
parent 693325c832
commit 7ece1c0b8e
@@ -413,6 +413,7 @@ filegroup(
         "//tensorflow/contrib/makefile:all_files",
         "//tensorflow/contrib/meta_graph_transform:all_files",
         "//tensorflow/contrib/metrics:all_files",
+        "//tensorflow/contrib/model_pruning:all_files",
         "//tensorflow/contrib/mpi_collectives:all_files",
         "//tensorflow/contrib/ndlstm:all_files",
         "//tensorflow/contrib/nearest_neighbor:all_files",
@@ -57,6 +57,7 @@ py_library(
         "//tensorflow/contrib/memory_stats:memory_stats_py",
         "//tensorflow/contrib/meta_graph_transform",
         "//tensorflow/contrib/metrics:metrics_py",
+        "//tensorflow/contrib/model_pruning",
         "//tensorflow/contrib/nccl:nccl_py",
         "//tensorflow/contrib/ndlstm",
         "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
@@ -51,6 +51,7 @@ from tensorflow.contrib import lookup
 from tensorflow.contrib import losses
 from tensorflow.contrib import memory_stats
 from tensorflow.contrib import metrics
+from tensorflow.contrib import model_pruning
 from tensorflow.contrib import nccl
 from tensorflow.contrib import nn
 from tensorflow.contrib import opt
@@ -518,6 +518,11 @@ add_python_module("tensorflow/contrib/metrics/python")
 add_python_module("tensorflow/contrib/metrics/python/kernel_tests")
 add_python_module("tensorflow/contrib/metrics/python/metrics")
 add_python_module("tensorflow/contrib/metrics/python/ops")
+add_python_module("tensorflow/contrib/model_pruning")
+add_python_module("tensorflow/contrib/model_pruning/examples")
+add_python_module("tensorflow/contrib/model_pruning/examples/cifar10")
+add_python_module("tensorflow/contrib/model_pruning/python")
+add_python_module("tensorflow/contrib/model_pruning/python/layers")
 add_python_module("tensorflow/contrib/ndlstm")
 add_python_module("tensorflow/contrib/ndlstm/python")
 add_python_module("tensorflow/contrib/nn")
tensorflow/contrib/model_pruning/BUILD (new file, 139 lines)
@@ -0,0 +1,139 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "py_test")

py_library(
    name = "core_layers",
    srcs = ["python/layers/core_layers.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:layers",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform",
    ],
)

py_library(
    name = "layers",
    srcs = ["python/layers/layers.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":core_layers",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/layers:layers_py",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "layers_test",
    size = "small",
    srcs = ["python/layers/layers_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":layers",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "learning",
    srcs = ["python/learning.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/slim",
    ],
)

py_library(
    name = "rnn_cells",
    srcs = ["python/layers/rnn_cells.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":core_layers",
    ],
)

py_library(
    name = "pruning",
    srcs = ["python/pruning.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":core_layers",
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/python:platform",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "pruning_test",
    size = "small",
    srcs = ["python/pruning_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":pruning",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "rnn_cells_test",
    size = "small",
    srcs = ["python/layers/rnn_cells_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":pruning",
        ":rnn_cells",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "init_py",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
)

# Top-level library
py_library(
    name = "model_pruning",
    srcs_version = "PY2AND3",
    deps = [
        ":init_py",
        ":layers",
        ":learning",
        ":pruning",
        ":rnn_cells",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/model_pruning/README.md (new file, 195 lines)
@@ -0,0 +1,195 @@
# Model pruning: Training tensorflow models to have masked connections

This document describes the API that facilitates magnitude-based pruning of a
neural network's weight tensors. The API helps inject the necessary TensorFlow
ops into the training graph so the model can be pruned while it is being
trained.

### Model creation

The first step involves adding mask and threshold variables to the layers that
need to undergo pruning. The mask variable has the same shape as the layer's
weight tensor and determines which of the weights participate in the forward
execution of the graph. This can be achieved by wrapping the weight tensor of
the layer with the `apply_mask` function provided in
[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py).
For example:

```python
conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding)
```

This creates a convolutional layer with the additional mask and threshold
variables.

Alternatively, the API also provides variants of TensorFlow layers with these
auxiliary variables built in (see
[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers)).
Layers currently supported (a usage sketch follows this list):

* [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83)

* [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241)

* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)
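The sketch below shows how the masked layer variants might be combined; it is
illustrative only, and the input shape and layer sizes are assumptions rather
than recommended settings:

```python
import tensorflow as tf

from tensorflow.contrib import model_pruning

# Input shape and layer sizes here are illustrative assumptions.
images = tf.placeholder(tf.float32, [128, 24, 24, 3])

# Drop-in analogues of tf.contrib.layers.conv2d / fully_connected that also
# create the auxiliary mask and threshold variables updated by the pruning ops.
net = model_pruning.masked_conv2d(images, num_outputs=64, kernel_size=3)
net = tf.contrib.layers.flatten(net)
logits = model_pruning.masked_fully_connected(net, num_outputs=10)
```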
### Pruning-related hyperparameters

The pruning library allows for specification of the following hyperparameters:

| Hyperparameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common TensorFlow name_scope. |
| begin_pruning_step | integer | 0 | The global step at which to begin pruning. |
| end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1, implying that pruning continues until training stops. |
| do_not_prune | list of strings | [""] | List of layers that are not pruned. |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds. |
| pruning_frequency | integer | 10 | How often the masks should be updated (in # of global_steps). |
| nbins | integer | 255 | Number of bins to use for histogram computation. |
| initial_sparsity | float | 0.0 | Initial sparsity value. |
| target_sparsity | float | 0.5 | Target sparsity value. |
| sparsity_function_begin_step | integer | 0 | The global step at which the gradual sparsity function begins to take effect. |
| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function. |
| sparsity_function_exponent | float | 3.0 | exponent = 1 varies sparsity linearly between initial and final; exponent > 1 varies sparsity more slowly towards the end than the beginning. |
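In practice these hyperparameters are supplied as a single comma-separated
string and parsed into a spec with `pruning.get_pruning_hparams()` (also shown
in the training-graph example further below); a minimal sketch, with
illustrative values:

```python
from tensorflow.contrib.model_pruning.python import pruning

# Illustrative values only; any unspecified hyperparameter keeps the default
# listed in the table above.
hparam_string = ('name=example_pruning,begin_pruning_step=1000,'
                 'end_pruning_step=50000,target_sparsity=0.75')
pruning_hparams = pruning.get_pruning_hparams().parse(hparam_string)
print(pruning_hparams.target_sparsity)  # 0.75
```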
The sparsity $$s_t$$ at global step $$t$$ is given by:

$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$

The interval between sparsity_function_begin_step and sparsity_function_end_step
is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$
is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
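The schedule is easy to reproduce outside the library as a sanity check; the
function below is a plain-Python restatement of the formula above, not part of
the API:

```python
def gradual_sparsity(step, initial_sparsity=0.0, target_sparsity=0.5,
                     begin_step=0, end_step=100, exponent=3.0):
  """Restates the gradual sparsity formula for a given global step."""
  # Assumes end_step > begin_step; clamp the step to the active range.
  step = max(begin_step, min(step, end_step))
  fraction = (step - begin_step) / float(end_step - begin_step)
  return (target_sparsity +
          (initial_sparsity - target_sparsity) * (1.0 - fraction)**exponent)

# With the defaults, sparsity reaches 0.5 + (0 - 0.5) * 0.5**3 = 0.4375
# halfway through the schedule.
print(gradual_sparsity(50))
```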

### Adding pruning ops to the training graph

The final step involves adding ops to the training graph that monitor the
distribution of the layer's weight magnitudes and determine the layer
threshold, such that masking all the weights below this threshold achieves the
sparsity level desired for the current training step. This can be achieved as
follows:

```python
tf.app.flags.DEFINE_string(
    'pruning_hparams', '',
    """Comma separated list of pruning-related hyperparameters""")

with tf.Graph().as_default():

  # Create global step variable
  global_step = tf.train.get_global_step()

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning specification
  p = pruning.Pruning(pruning_hparams, global_step=global_step)

  # Add conditional mask update op. Executing this op will update all
  # the masks in the graph if the current global step is in the range
  # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
  mask_update_op = p.conditional_mask_update_op()

  # Add summaries to keep track of the sparsity in different layers during
  # training
  p.add_pruning_summaries()

  with tf.train.MonitoredTrainingSession(...) as mon_sess:
    # Run the usual training op in the tf session
    mon_sess.run(train_op)

    # Update the masks by running the mask_update_op
    mon_sess.run(mask_update_op)
```

## Example: Pruning and training deep CNNs on the cifar10 dataset

Please see https://www.tensorflow.org/tutorials/deep_cnn for details on the
neural network architecture, setting up inputs, etc. The additional changes
needed to incorporate pruning are captured in the following:

* [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py)
  creates a deep CNN with the same architecture, but adds mask and threshold
  variables for each of the weight tensors in the convolutional and
  locally-connected layers.

* [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py)
  adds pruning ops to the training graph as described above.

To train the pruned version of cifar10:

```bash
$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
```

Eval:

```shell
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```

TODO(suyoggupta): Add figures showing the sparsity function, sparsity for
different layers etc.
tensorflow/contrib/model_pruning/__init__.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model pruning implementation in tensorflow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.contrib.model_pruning.python.layers.layers import masked_conv2d
from tensorflow.contrib.model_pruning.python.layers.layers import masked_convolution
from tensorflow.contrib.model_pruning.python.layers.layers import masked_fully_connected
from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell
from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedLSTMCell
from tensorflow.contrib.model_pruning.python.learning import train
from tensorflow.contrib.model_pruning.python.pruning import apply_mask
from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights
from tensorflow.contrib.model_pruning.python.pruning import get_masks
from tensorflow.contrib.model_pruning.python.pruning import get_thresholds
from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity
from tensorflow.contrib.model_pruning.python.pruning import get_weights
from tensorflow.contrib.model_pruning.python.pruning import Pruning
# pylint: enable=unused-import

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'masked_convolution', 'masked_conv2d', 'masked_fully_connected',
    'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask',
    'get_masked_weights', 'get_masks', 'get_thresholds', 'get_weights',
    'get_weight_sparsity', 'Pruning'
]

remove_undocumented(__name__, _allowed_symbols)
tensorflow/contrib/model_pruning/examples/cifar10/BUILD (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Description:
#   Example TensorFlow models for CIFAR-10

package(
    default_visibility = [
        "//tensorflow:internal",
    ],
)

licenses(["notice"])  # Apache 2.0

py_library(
    name = "cifar10_input",
    srcs = ["cifar10_input.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

py_library(
    name = "cifar10_pruning",
    srcs = ["cifar10_pruning.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "cifar10_eval",
    srcs = [
        "cifar10_eval.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_pruning",
    ],
)

py_binary(
    name = "cifar10_train",
    srcs = [
        "cifar10_train.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_pruning",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation for CIFAR-10.

Accuracy:
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
of data) as judged by cifar10_eval.py.

Speed:
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
in 0.25-0.35 sec (i.e. 350 - 600 images/sec). The model reaches ~86%
accuracy after 100K steps in 8 hours of training time.

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import math
import sys
import time

import numpy as np
import tensorflow as tf

from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10

FLAGS = None


def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run Eval once.

  Args:
    saver: Saver.
    summary_writer: Summary writer.
    top_k_op: Top K op.
    summary_op: Summary op.
  """
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # Restores from checkpoint
      saver.restore(sess, ckpt.model_checkpoint_path)
      # Assuming model_checkpoint_path looks something like:
      #   /my-favorite-path/cifar10_train/model.ckpt-0,
      # extract global_step from it.
      global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
      print('No checkpoint file found')
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

      num_iter = int(math.ceil(FLAGS.num_examples / 128))
      true_count = 0  # Counts the number of correct predictions.
      total_sample_count = num_iter * 128
      step = 0
      while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1

      # Compute precision @ 1.
      precision = true_count / total_sample_count
      print('%s: precision @ 1 = %.3f' % (datetime.datetime.now(), precision))

      summary = tf.Summary()
      summary.ParseFromString(sess.run(summary_op))
      summary.value.add(tag='Precision @ 1', simple_value=precision)
      summary_writer.add_summary(summary, global_step)
    except Exception as e:  # pylint: disable=broad-except
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)


def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  evaluate()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--eval_dir',
      type=str,
      default='/tmp/cifar10_eval',
      help='Directory where to write event logs.')
  parser.add_argument(
      '--eval_data',
      type=str,
      default='test',
      help="""Either 'test' or 'train_eval'.""")
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default='/tmp/cifar10_train',
      help="""Directory where to read model checkpoints.""")
  parser.add_argument(
      '--eval_interval_secs',
      type=int,
      default=60 * 5,
      help='How often to run the eval.')
  parser.add_argument(
      '--num_examples',
      type=int,
      default=10000,
      help='Number of examples to run.')
  parser.add_argument(
      '--run_once',
      type=bool,
      default=False,
      help='Whether to run eval only once.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py (new file, 256 lines)
@@ -0,0 +1,256 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Routine for decoding the CIFAR-10 binary file format."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
# architecture will change and any model would need to be retrained.
IMAGE_SIZE = 24

# Global constants describing the CIFAR-10 data set.
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


def read_cifar10(filename_queue):
  """Reads and parses examples from CIFAR10 data files.

  Recommendation: if you want N-way read parallelism, call this function
  N times. This will give you N independent Readers reading different
  files & positions within those files, which will give better mixing of
  examples.

  Args:
    filename_queue: A queue of strings with the filenames to read from.

  Returns:
    An object representing a single example, with the following fields:
      height: number of rows in the result (32)
      width: number of columns in the result (32)
      depth: number of color channels in the result (3)
      key: a scalar string Tensor describing the filename & record number
        for this example.
      label: an int32 Tensor with the label in the range 0..9.
      uint8image: a [height, width, depth] uint8 Tensor with the image data
  """

  class CIFAR10Record(object):
    pass

  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue. No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result


def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Construct a queued batch of images and labels.

  Args:
    image: 3-D Tensor of [height, width, 3] of type.float32.
    label: 1-D Tensor of type.int32
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue.

  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])


def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.

  # Randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Randomly flip the image horizontally.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order of their operation.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)


def inputs(eval_data, data_dir, batch_size):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for evaluation.
  # Crop the central [height, width] of the image.
  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                         width, height)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(resized_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(num_examples_per_epoch *
                           min_fraction_of_examples_in_queue)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=False)
@ -0,0 +1,395 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Builds the CIFAR-10 network with additional variables to support pruning.
|
||||||
|
|
||||||
|
Summary of available functions:
|
||||||
|
|
||||||
|
# Compute input images and labels for training. If you would like to run
|
||||||
|
# evaluations, use inputs() instead.
|
||||||
|
inputs, labels = distorted_inputs()
|
||||||
|
|
||||||
|
# Compute inference on the model inputs to make a prediction.
|
||||||
|
predictions = inference(inputs)
|
||||||
|
|
||||||
|
# Compute the total loss of the prediction with respect to the labels.
|
||||||
|
loss = loss(predictions, labels)
|
||||||
|
|
||||||
|
# Create a graph to run one step of training with respect to the loss.
|
||||||
|
train_op = train(loss, global_step)
|
||||||
|
"""
|
||||||
|
# pylint: disable=missing-docstring
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
from six.moves import urllib
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input
|
||||||
|
from tensorflow.contrib.model_pruning.python import pruning
|
||||||
|
|
||||||
|
# Global constants describing the CIFAR-10 data set.
|
||||||
|
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
|
||||||
|
NUM_CLASSES = cifar10_input.NUM_CLASSES
|
||||||
|
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
||||||
|
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
|
||||||
|
BATCH_SIZE = 128
|
||||||
|
DATA_DIR = '/tmp/cifar10_data'
|
||||||
|
|
||||||
|
# Constants describing the training process.
|
||||||
|
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
|
||||||
|
NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
|
||||||
|
LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
|
||||||
|
INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
|
||||||
|
|
||||||
|
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
|
||||||
|
# to differentiate the operations. Note that this prefix is removed from the
|
||||||
|
# names of the summaries when visualizing a model.
|
||||||
|
TOWER_NAME = 'tower'
|
||||||
|
|
||||||
|
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
|
||||||
|
|
||||||
|
|
||||||
|
def _activation_summary(x):
|
||||||
|
"""Helper to create summaries for activations.
|
||||||
|
|
||||||
|
Creates a summary that provides a histogram of activations.
|
||||||
|
Creates a summary that measures the sparsity of activations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor
|
||||||
|
Returns:
|
||||||
|
nothing
|
||||||
|
"""
|
||||||
|
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
|
||||||
|
# session. This helps the clarity of presentation on tensorboard.
|
||||||
|
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
|
||||||
|
tf.summary.histogram(tensor_name + '/activations', x)
|
||||||
|
tf.summary.scalar(tensor_name + '/sparsity',
|
||||||
|
tf.nn.zero_fraction(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _variable_on_cpu(name, shape, initializer):
|
||||||
|
"""Helper to create a Variable stored on CPU memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the variable
|
||||||
|
shape: list of ints
|
||||||
|
initializer: initializer for Variable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Variable Tensor
|
||||||
|
"""
|
||||||
|
with tf.device('/cpu:0'):
|
||||||
|
dtype = tf.float32
|
||||||
|
var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
|
||||||
|
return var
|
||||||
|
|
||||||
|
|
||||||
|
def _variable_with_weight_decay(name, shape, stddev, wd):
|
||||||
|
"""Helper to create an initialized Variable with weight decay.
|
||||||
|
|
||||||
|
Note that the Variable is initialized with a truncated normal distribution.
|
||||||
|
A weight decay is added only if one is specified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the variable
|
||||||
|
shape: list of ints
|
||||||
|
stddev: standard deviation of a truncated Gaussian
|
||||||
|
wd: add L2Loss weight decay multiplied by this float. If None, weight
|
||||||
|
decay is not added for this Variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Variable Tensor
|
||||||
|
"""
|
||||||
|
dtype = tf.float32
|
||||||
|
var = _variable_on_cpu(
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
|
||||||
|
if wd is not None:
|
||||||
|
weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
|
||||||
|
tf.add_to_collection('losses', weight_decay)
|
||||||
|
return var
|
||||||
|
|
||||||
|
|
||||||
|
def distorted_inputs():
|
||||||
|
"""Construct distorted input for CIFAR training using the Reader ops.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||||
|
labels: Labels. 1D tensor of [batch_size] size.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no data_dir
|
||||||
|
"""
|
||||||
|
if not DATA_DIR:
|
||||||
|
raise ValueError('Please supply a data_dir')
|
||||||
|
data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
|
||||||
|
images, labels = cifar10_input.distorted_inputs(
|
||||||
|
data_dir=data_dir, batch_size=BATCH_SIZE)
|
||||||
|
return images, labels
|
||||||
|
|
||||||
|
|
||||||
|
def inputs(eval_data):
|
||||||
|
"""Construct input for CIFAR evaluation using the Reader ops.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_data: bool, indicating if one should use the train or eval data set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||||
|
labels: Labels. 1D tensor of [batch_size] size.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no data_dir
|
||||||
|
"""
|
||||||
|
if not DATA_DIR:
|
||||||
|
raise ValueError('Please supply a data_dir')
|
||||||
|
data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
|
||||||
|
images, labels = cifar10_input.inputs(
|
||||||
|
eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE)
|
||||||
|
return images, labels
|
||||||
|
|
||||||
|
|
||||||
|
def inference(images):
|
||||||
|
"""Build the CIFAR-10 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: Images returned from distorted_inputs() or inputs().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Logits.
|
||||||
|
"""
|
||||||
|
# We instantiate all variables using tf.get_variable() instead of
|
||||||
|
# tf.Variable() in order to share variables across multiple GPU training runs.
|
||||||
|
# If we only ran this model on a single GPU, we could simplify this function
|
||||||
|
# by replacing all instances of tf.get_variable() with tf.Variable().
|
||||||
|
#
|
||||||
|
# While instantiating conv and local layers, we add mask and threshold
|
||||||
|
# variables to the layer by calling the pruning.apply_mask() function.
|
||||||
|
# Note that the masks are applied only to the weight tensors
|
||||||
|
# conv1
|
||||||
|
with tf.variable_scope('conv1') as scope:
|
||||||
|
kernel = _variable_with_weight_decay('weights',
|
||||||
|
shape=[5, 5, 3, 64],
|
||||||
|
stddev=5e-2,
|
||||||
|
wd=0.0)
|
||||||
|
|
||||||
|
conv = tf.nn.conv2d(
|
||||||
|
images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
|
||||||
|
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
|
||||||
|
pre_activation = tf.nn.bias_add(conv, biases)
|
||||||
|
conv1 = tf.nn.relu(pre_activation, name=scope.name)
|
||||||
|
_activation_summary(conv1)
|
||||||
|
|
||||||
|
# pool1
|
||||||
|
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
|
||||||
|
padding='SAME', name='pool1')
|
||||||
|
# norm1
|
||||||
|
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
|
||||||
|
name='norm1')
|
||||||
|
|
||||||
|
# conv2
|
||||||
|
with tf.variable_scope('conv2') as scope:
|
||||||
|
kernel = _variable_with_weight_decay('weights',
|
||||||
|
shape=[5, 5, 64, 64],
|
||||||
|
stddev=5e-2,
|
||||||
|
wd=0.0)
|
||||||
|
conv = tf.nn.conv2d(
|
||||||
|
norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
|
||||||
|
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
|
||||||
|
pre_activation = tf.nn.bias_add(conv, biases)
|
||||||
|
conv2 = tf.nn.relu(pre_activation, name=scope.name)
|
||||||
|
_activation_summary(conv2)
|
||||||
|
|
||||||
|
# norm2
|
||||||
|
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
|
||||||
|
name='norm2')
|
||||||
|
# pool2
|
||||||
|
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
|
||||||
|
strides=[1, 2, 2, 1], padding='SAME', name='pool2')
|
||||||
|
|
||||||
|
# local3
|
||||||
|
with tf.variable_scope('local3') as scope:
|
||||||
|
# Move everything into depth so we can perform a single matrix multiply.
|
||||||
|
reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
|
||||||
|
dim = reshape.get_shape()[1].value
|
||||||
|
weights = _variable_with_weight_decay('weights', shape=[dim, 384],
|
||||||
|
stddev=0.04, wd=0.004)
|
||||||
|
biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
|
||||||
|
local3 = tf.nn.relu(
|
||||||
|
tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
|
||||||
|
name=scope.name)
|
||||||
|
_activation_summary(local3)
|
||||||
|
|
||||||
|
# local4
|
||||||
|
with tf.variable_scope('local4') as scope:
|
||||||
|
weights = _variable_with_weight_decay('weights', shape=[384, 192],
|
||||||
|
stddev=0.04, wd=0.004)
|
||||||
|
biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
|
||||||
|
local4 = tf.nn.relu(
|
||||||
|
tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
|
||||||
|
name=scope.name)
|
||||||
|
_activation_summary(local4)
|
||||||
|
|
||||||
|
# linear layer(WX + b),
|
||||||
|
# We don't apply softmax here because
|
||||||
|
# tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
|
||||||
|
# and performs the softmax internally for efficiency.
|
||||||
|
with tf.variable_scope('softmax_linear') as scope:
|
||||||
|
weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
|
||||||
|
stddev=1/192.0, wd=0.0)
|
||||||
|
biases = _variable_on_cpu('biases', [NUM_CLASSES],
|
||||||
|
tf.constant_initializer(0.0))
|
||||||
|
softmax_linear = tf.add(
|
||||||
|
tf.matmul(local4, pruning.apply_mask(weights, scope)),
|
||||||
|
biases,
|
||||||
|
name=scope.name)
|
||||||
|
_activation_summary(softmax_linear)
|
||||||
|
|
||||||
|
return softmax_linear
|
||||||
|
|
||||||
|
|
||||||
|
def loss(logits, labels):
|
||||||
|
"""Add L2Loss to all the trainable variables.
|
||||||
|
|
||||||
|
Add summary for "Loss" and "Loss/avg".
|
||||||
|
Args:
|
||||||
|
logits: Logits from inference().
|
||||||
|
labels: Labels from distorted_inputs or inputs(). 1-D tensor
|
||||||
|
of shape [batch_size]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss tensor of type float.
|
||||||
|
"""
|
||||||
|
# Calculate the average cross entropy loss across the batch.
|
||||||
|
labels = tf.cast(labels, tf.int64)
|
||||||
|
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||||
|
labels=labels, logits=logits, name='cross_entropy_per_example')
|
||||||
|
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
|
||||||
|
tf.add_to_collection('losses', cross_entropy_mean)
|
||||||
|
|
||||||
|
# The total loss is defined as the cross entropy loss plus all of the weight
|
||||||
|
# decay terms (L2 loss).
|
||||||
|
return tf.add_n(tf.get_collection('losses'), name='total_loss')
|
||||||
|
|
||||||
|
|
||||||
|
def _add_loss_summaries(total_loss):
|
||||||
|
"""Add summaries for losses in CIFAR-10 model.
|
||||||
|
|
||||||
|
Generates moving average for all losses and associated summaries for
|
||||||
|
visualizing the performance of the network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_loss: Total loss from loss().
|
||||||
|
Returns:
|
||||||
|
loss_averages_op: op for generating moving averages of losses.
|
||||||
|
"""
|
||||||
|
# Compute the moving average of all individual losses and the total loss.
|
||||||
|
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
|
||||||
|
losses = tf.get_collection('losses')
|
||||||
|
loss_averages_op = loss_averages.apply(losses + [total_loss])
|
||||||
|
|
||||||
|
# Attach a scalar summary to all individual losses and the total loss; do the
|
||||||
|
# same for the averaged version of the losses.
|
||||||
|
for l in losses + [total_loss]:
|
||||||
|
# Name each loss as '(raw)' and name the moving average version of the loss
|
||||||
|
# as the original loss name.
|
||||||
|
tf.summary.scalar(l.op.name + ' (raw)', l)
|
||||||
|
tf.summary.scalar(l.op.name, loss_averages.average(l))
|
||||||
|
|
||||||
|
return loss_averages_op
|
||||||
|
|
||||||
|
|
||||||
|
def train(total_loss, global_step):
|
||||||
|
"""Train CIFAR-10 model.
|
||||||
|
|
||||||
|
Create an optimizer and apply to all trainable variables. Add moving
|
||||||
|
average for all trainable variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_loss: Total loss from loss().
|
||||||
|
global_step: Integer Variable counting the number of training steps
|
||||||
|
processed.
|
||||||
|
Returns:
|
||||||
|
train_op: op for training.
|
||||||
|
"""
|
||||||
|
# Variables that affect learning rate.
|
||||||
|
num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
|
||||||
|
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
|
||||||
|
|
||||||
|
# Decay the learning rate exponentially based on the number of steps.
|
||||||
|
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
|
||||||
|
global_step,
|
||||||
|
decay_steps,
|
||||||
|
LEARNING_RATE_DECAY_FACTOR,
|
||||||
|
staircase=True)
|
||||||
|
tf.summary.scalar('learning_rate', lr)
  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op


def maybe_download_and_extract():
  """Download and extract the tarball from Alex's website."""
  dest_directory = DATA_DIR
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write(
          '\r>> Downloading %s %.1f%%' %
          (filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  tarfile.open(filepath, 'r:gz').extractall(dest_directory)
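
The 'losses' collection that loss() sums over is populated elsewhere in the
model: each layer that wants weight decay adds its own L2 term to the same
collection, and tf.add_n folds everything into the total loss. A minimal
standalone sketch of that pattern (the variable name and decay factor are
illustrative, not taken from this commit):

    import tensorflow as tf

    weights = tf.get_variable('weights', shape=[10, 10])
    # Each weight-decay term joins the shared 'losses' collection...
    weight_decay = tf.multiply(tf.nn.l2_loss(weights), 0.004,
                               name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
    # ...so the total loss is cross entropy plus every decay term at once.
    total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')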
@ -0,0 +1,159 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train pruned CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py when the target sparsity in
cifar10_pruning_spec.pbtxt is set to zero.

Results:
Sparsity | Accuracy after 150K steps
-------- | -------------------------
0%       | 86%
50%      | 86%
75%      | TODO(suyoggupta)
90%      | TODO(suyoggupta)
95%      | 77%

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import sys
import time

import tensorflow as tf

from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
from tensorflow.contrib.model_pruning.python import pruning

FLAGS = None


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters.
    pruning_hparams = pruning.get_pruning_hparams().parse(
        FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters.
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the
    # masks. The conditional_mask_update_op will update the masks only when
    # the training step is in [begin_pruning_step, end_pruning_step] specified
    # in the pruning spec proto.
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers.
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks.
        mon_sess.run(mask_update_op)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/cifar10_train',
      help='Directory where to write event logs and checkpoint.')
  parser.add_argument(
      '--pruning_hparams',
      type=str,
      default='',
      help='Comma separated list of pruning-related hyperparameters.')
  parser.add_argument(
      '--max_steps',
      type=int,
      default=1000000,
      help='Number of batches to run.')
  parser.add_argument(
      '--log_device_placement',
      type=bool,
      default=False,
      help='Whether to log device placement.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
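
The --pruning_hparams flag above feeds pruning.get_pruning_hparams().parse(),
so the sparsity schedule can be set entirely from the command line. A hedged
sketch of parsing such a string (the hyperparameter names echo the begin/end
pruning step and target sparsity mentioned in the comments above; the
authoritative list lives in the pruning library, not in this file):

    from tensorflow.contrib.model_pruning.python import pruning

    hparams = pruning.get_pruning_hparams().parse(
        'name=cifar10_pruning,begin_pruning_step=10000,'
        'end_pruning_step=100000,target_sparsity=0.9')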
477
tensorflow/contrib/model_pruning/python/layers/core_layers.py
Normal file
@ -0,0 +1,477 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the core layer classes for model pruning and its functional aliases.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops

MASK_COLLECTION = 'masks'
THRESHOLD_COLLECTION = 'thresholds'
MASKED_WEIGHT_COLLECTION = 'masked_weights'
WEIGHT_COLLECTION = 'kernel'
# The 'weights' part of the name is needed for the quantization library
# to recognize that the kernel should be quantized.
MASKED_WEIGHT_NAME = 'weights/masked_weight'


class _MaskedConv(base.Layer):
  """Abstract nD convolution layer (private, used as implementation base).

  This layer creates a convolution kernel that is convolved
  (actually cross-correlated) with the layer input to produce a tensor of
  outputs. The weight tensor of this layer is masked.
  If `use_bias` is True (and a `bias_initializer` is provided),
  a bias vector is created and added to the outputs. Finally, if
  `activation` is not `None`, it is applied to the outputs as well.

  Arguments:
    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      length of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the stride length of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, ..., channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, ...)`.
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, no bias will
      be applied.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
  """

  def __init__(self,
               rank,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format='channels_last',
               dilation_rate=1,
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(_MaskedConv, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.rank = rank
    self.filters = filters
    self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
    self.strides = utils.normalize_tuple(strides, rank, 'strides')
    self.padding = utils.normalize_padding(padding)
    self.data_format = utils.normalize_data_format(data_format)
    self.dilation_rate = utils.normalize_tuple(dilation_rate, rank,
                                               'dilation_rate')
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.input_spec = base.InputSpec(ndim=self.rank + 2)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    channel_axis = 1 if self.data_format == 'channels_first' else -1
    if input_shape[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis].value
    kernel_shape = self.kernel_size + (input_dim, self.filters)
    self.mask = self.add_variable(
        name='mask',
        shape=kernel_shape,
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)

    self.kernel = self.add_variable(
        name='kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        trainable=True,
        dtype=self.dtype)

    self.threshold = self.add_variable(
        name='threshold',
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)

    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
                                           MASKED_WEIGHT_NAME)

    ops.add_to_collection(MASK_COLLECTION, self.mask)
    ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
    ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
    ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)

    if self.use_bias:
      self.bias = self.add_variable(
          name='bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.input_spec = base.InputSpec(
        ndim=self.rank + 2, axes={channel_axis: input_dim})
    self.built = True

  def call(self, inputs):
    outputs = nn.convolution(
        input=inputs,
        filter=self.masked_kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, self.rank + 2))

    if self.bias is not None:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # nn.bias_add does not accept a 1D input tensor.
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        if self.rank == 2:
          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
        if self.rank == 3:
          # As of Mar 2017, direct addition is significantly slower than
          # bias_add when computing gradients. To use bias_add, we collapse Z
          # and Y into a single dimension to obtain a 4D input tensor.
          outputs_shape = outputs.shape.as_list()
          outputs_4d = array_ops.reshape(outputs, [
              outputs_shape[0], outputs_shape[1],
              outputs_shape[2] * outputs_shape[3], outputs_shape[4]
          ])
          outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
          outputs = array_ops.reshape(outputs_4d, outputs_shape)
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def _compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if self.data_format == 'channels_last':
      space = input_shape[1:-1]
      new_space = []
      for i in range(len(space)):
        new_dim = utils.conv_output_length(
            space[i],
            self.kernel_size[i],
            padding=self.padding,
            stride=self.strides[i],
            dilation=self.dilation_rate[i])
        new_space.append(new_dim)
      return tensor_shape.TensorShape([input_shape[0]] + new_space +
                                      [self.filters])
    else:
      space = input_shape[2:]
      new_space = []
      for i in range(len(space)):
        new_dim = utils.conv_output_length(
            space[i],
            self.kernel_size[i],
            padding=self.padding,
            stride=self.strides[i],
            dilation=self.dilation_rate[i])
        new_space.append(new_dim)
      return tensor_shape.TensorShape([input_shape[0], self.filters] +
                                      new_space)


class MaskedConv2D(_MaskedConv):
  """2D convolution layer (e.g. spatial convolution over images).

  This layer creates a convolution kernel that is convolved
  (actually cross-correlated) with the layer input to produce a tensor of
  outputs. If `use_bias` is True (and a `bias_initializer` is provided),
  a bias vector is created and added to the outputs. Finally, if
  `activation` is not `None`, it is applied to the outputs as well.

  Arguments:
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the
      height and width of the 2D convolution window.
      Can be a single integer to specify the same value for
      all spatial dimensions.
    strides: An integer or tuple/list of 2 integers,
      specifying the strides of the convolution along the height and width.
      Can be a single integer to specify the same value for
      all spatial dimensions.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, height, width)`.
    dilation_rate: An integer or tuple/list of 2 integers, specifying
      the dilation rate to use for dilated convolution.
      Can be a single integer to specify the same value for
      all spatial dimensions.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any stride value != 1.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, no bias will
      be applied.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format='channels_last',
               dilation_rate=(1, 1),
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(MaskedConv2D, self).__init__(
        rank=2,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        trainable=trainable,
        name=name,
        **kwargs)


class MaskedFullyConnected(base.Layer):
  """Fully-connected layer class with masked weights.

  This layer implements the operation:
  `outputs = activation(inputs * kernel + bias)`
  where `activation` is the activation function passed as the `activation`
  argument (if not `None`), `kernel` is a weights matrix created by the layer,
  and `bias` is a bias vector created by the layer
  (only if `use_bias` is `True`).

  Note: if the input to the layer has a rank greater than 2, then it is
  flattened prior to the initial matrix multiply by `kernel`.

  Arguments:
    units: Integer or Long, dimensionality of the output space.
    activation: Activation function (callable). Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: Initializer function for the weight matrix.
    bias_initializer: Initializer function for the bias.
    kernel_regularizer: Regularizer function for the weight matrix.
    bias_regularizer: Regularizer function for the bias.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such cases.
    reuse: Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Properties:
    units: Python integer, dimensionality of the output space.
    activation: Activation function (callable).
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: Initializer instance (or name) for the weight matrix.
    bias_initializer: Initializer instance (or name) for the bias.
    kernel_regularizer: Regularizer instance for the weight matrix (callable).
    bias_regularizer: Regularizer instance for the bias (callable).
    activity_regularizer: Regularizer instance for the output (callable).
    kernel: Weight matrix (TensorFlow variable or tensor).
    bias: Bias vector, if applicable (TensorFlow variable or tensor).
  """

  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(MaskedFullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.input_spec = base.InputSpec(min_ndim=2)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if input_shape[-1].value is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    self.input_spec = base.InputSpec(
        min_ndim=2, axes={-1: input_shape[-1].value})

    self.kernel = self.add_variable(
        'kernel',
        shape=[input_shape[-1].value, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    self.mask = self.add_variable(
        name='mask',
        shape=[input_shape[-1].value, self.units],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)

    self.threshold = self.add_variable(
        name='threshold',
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)

    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
                                           MASKED_WEIGHT_NAME)

    ops.add_to_collection(MASK_COLLECTION, self.mask)
    ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
    ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
    ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)

    if self.use_bias:
      self.bias = self.add_variable(
          'bias',
          shape=[self.units,],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    shape = inputs.get_shape().as_list()
    output_shape = shape[:-1] + [self.units]
    if len(output_shape) > 2:
      # Broadcasting is required for the inputs.
      outputs = standard_ops.tensordot(inputs, self.masked_kernel,
                                       [[len(shape) - 1], [0]])
      # Reshape the output back to the original ndim of the input.
      outputs.set_shape(output_shape)
    else:
      outputs = standard_ops.matmul(inputs, self.masked_kernel)
    if self.use_bias:
      outputs = nn.bias_add(outputs, self.bias)
    if self.activation is not None:
      return self.activation(outputs)  # pylint: disable=not-callable
    return outputs

  def _compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(2)
    if input_shape[-1].value is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    return input_shape[:-1].concatenate(self.units)
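
Each masked layer above keeps three extra variables per weight tensor: a
non-trainable ones-initialized mask, a non-trainable scalar threshold, and
the product mask * kernel that actually feeds the conv or matmul. A
standalone numpy sketch of the thresholding idea (illustrative only; in the
library the mask updates come from the pruning ops, not from inline code like
this):

    import numpy as np

    weights = np.random.randn(3, 3, 8, 32).astype(np.float32)
    threshold = 0.5  # illustrative; the pruning ops schedule this value
    # Weights whose magnitude falls below the threshold get a zero mask entry,
    mask = (np.abs(weights) > threshold).astype(np.float32)
    # and the layer uses the elementwise product instead of the raw kernel.
    masked_weights = mask * weights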
364
tensorflow/contrib/model_pruning/python/layers/layers.py
Normal file
@ -0,0 +1,364 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow layers with added variables for parameter masking.

Branched from tensorflow/contrib/layers/python/layers/layers.py
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables


def _model_variable_getter(getter,
                           name,
                           shape=None,
                           dtype=None,
                           initializer=None,
                           regularizer=None,
                           trainable=True,
                           collections=None,
                           caching_device=None,
                           partitioner=None,
                           rename=None,
                           use_resource=None,
                           **_):
  """Getter that uses model_variable for compatibility with core layers."""
  short_name = name.split('/')[-1]
  if rename and short_name in rename:
    name_components = name.split('/')
    name_components[-1] = rename[short_name]
    name = '/'.join(name_components)
  return variables.model_variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      collections=collections,
      trainable=trainable,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=getter,
      use_resource=use_resource)


def _build_variable_getter(rename=None):
  """Build a model variable getter that respects scope getter and renames."""

  # VariableScope will nest the getters.
  def layer_variable_getter(getter, *args, **kwargs):
    kwargs['rename'] = rename
    return _model_variable_getter(getter, *args, **kwargs)

  return layer_variable_getter


def _add_variable_to_collections(variable, collections_set, collections_name):
  """Adds variable (or all its parts) to all collections with that name."""
  collections = utils.get_variable_collections(collections_set,
                                               collections_name) or []
  variables_list = [variable]
  if isinstance(variable, tf_variables.PartitionedVariable):
    variables_list = [v for v in variable]
  for collection in collections:
    for var in variables_list:
      if var not in ops.get_collection(collection):
        ops.add_to_collection(collection, var)


@add_arg_scope
def masked_convolution(inputs,
                       num_outputs,
                       kernel_size,
                       stride=1,
                       padding='SAME',
                       data_format=None,
                       rate=1,
                       activation_fn=nn.relu,
                       normalizer_fn=None,
                       normalizer_params=None,
                       weights_initializer=initializers.xavier_initializer(),
                       weights_regularizer=None,
                       biases_initializer=init_ops.zeros_initializer(),
                       biases_regularizer=None,
                       reuse=None,
                       variables_collections=None,
                       outputs_collections=None,
                       trainable=True,
                       scope=None):
  """Adds a 2D convolution followed by an optional batch_norm layer.

  The layer creates a mask variable on top of the weight variable. The input
  to the convolution operation is the elementwise multiplication of the mask
  variable and the weight variable.

  It is required that 1 <= N <= 3.

  `convolution` creates a variable called `weights`, representing the
  convolutional kernel, that is convolved (actually cross-correlated) with the
  `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
  provided (such as `batch_norm`), it is then applied. Otherwise, if
  `normalizer_fn` is None and a `biases_initializer` is provided then a
  `biases` variable would be created and added to the activations. Finally,
  if `activation_fn` is not `None`, it is applied to the activations as well.

  Performs atrous convolution with input stride/dilation rate equal to `rate`
  if a value > 1 for any dimension of `rate` is specified. In this case
  `stride` values != 1 are not supported.

  Args:
    inputs: A Tensor of rank N+2 of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    num_outputs: Integer, the number of output filters.
    kernel_size: A sequence of N positive integers specifying the spatial
      dimensions of the filters. Can be a single integer to specify the same
      value for all spatial dimensions.
    stride: A sequence of N positive integers specifying the stride at which
      to compute output. Can be a single integer to specify the same value for
      all spatial dimensions. Specifying any `stride` value != 1 is
      incompatible with specifying any `rate` value != 1.
    padding: One of `"VALID"` or `"SAME"`.
    data_format: A string or None. Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if
      `data_format` does not start with "NC"), or the second dimension (if
      `data_format` starts with "NC"). For N=1, the valid values are "NWC"
      (default) and "NCW". For N=2, the valid values are "NHWC" (default) and
      "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    rate: A sequence of N positive integers specifying the dilation rate to
      use for atrous convolution. Can be a single integer to specify the same
      value for all spatial dimensions. Specifying any `rate` value != 1 is
      incompatible with specifying any `stride` value != 1.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A tensor representing the output of the operation.

  Raises:
    ValueError: If `data_format` is invalid.
    ValueError: Both `rate` and `stride` are not uniformly 1.
  """
  if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
    raise ValueError('Invalid data_format: %r' % (data_format,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope, 'Conv', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    input_rank = inputs.get_shape().ndims

    if input_rank == 3:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)
    elif input_rank == 4:
      layer_class = core.MaskedConv2D
    elif input_rank == 5:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)
    else:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)

    if data_format is None or data_format == 'NHWC':
      df = 'channels_last'
    elif data_format == 'NCHW':
      df = 'channels_first'
    else:
      raise ValueError('Unsupported data format', data_format)

    layer = layer_class(
        filters=num_outputs,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=df,
        dilation_rate=rate,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections,
                                 'weights')
    if layer.use_bias:
      _add_variable_to_collections(layer.bias, variables_collections,
                                   'biases')

    if normalizer_fn is not None:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)


masked_conv2d = masked_convolution


@add_arg_scope
def masked_fully_connected(
    inputs,
    num_outputs,
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer(),
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
  """Adds a sparse fully connected layer. The weight matrix is masked.

  `fully_connected` creates a variable called `weights`, representing a fully
  connected weight matrix, which is multiplied by the `inputs` to produce a
  `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
  `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
  None and a `biases_initializer` is provided then a `biases` variable would
  be created and added to the hidden units. Finally, if `activation_fn` is
  not `None`, it is applied to the hidden units as well.

  Note that if `inputs` have a rank greater than 2, then `inputs` is
  flattened prior to the initial matrix multiply by `weights`.

  Args:
    inputs: A tensor of at least rank 2 and static value for the last
      dimension; i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
    num_outputs: Integer or long, the number of output units in the layer.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for variable_scope.

  Returns:
    The tensor variable representing the result of the series of operations.

  Raises:
    ValueError: If x has rank less than 2 or if its last dimension is not set.
  """
  if not isinstance(num_outputs, six.integer_types):
    raise ValueError('num_outputs should be int or long, got %s.' %
                     (num_outputs,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope,
      'fully_connected', [inputs],
      reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    layer = core.MaskedFullyConnected(
        units=num_outputs,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections,
                                 'weights')
    if layer.bias is not None:
      _add_variable_to_collections(layer.bias, variables_collections,
                                   'biases')

    # Apply normalizer function / layer.
    if normalizer_fn is not None:
      if not normalizer_params:
        normalizer_params = {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
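
These wrappers drop into contrib-layers style model code: only the layer
constructors change, while the pruning schedule itself is driven by the
mask-update ops shown in cifar10_train.py. A hedged usage sketch (shapes and
sizes are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python.layers import layers

    images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    net = layers.masked_conv2d(images, num_outputs=64, kernel_size=3)
    # Default 'SAME' padding and stride 1 keep the 32x32 spatial shape.
    net = tf.reshape(net, [-1, 32 * 32 * 64])
    logits = layers.masked_fully_connected(net, 10, activation_fn=None)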
139
tensorflow/contrib/model_pruning/python/layers/layers_test.py
Normal file
@ -0,0 +1,139 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.contrib.model_pruning.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class MaskedConvolutionLayerTest(test.TestCase):

  def setUp(self):
    super(MaskedConvolutionLayerTest, self).setUp()
    self.height, self.width = 7, 9

  def testInvalidRank3(self):
    input_tensor = array_ops.ones((self.height, self.width, 3))
    with self.assertRaisesRegexp(ValueError, 'rank'):
      layers.masked_conv2d(input_tensor, 32, 3)

  def testInvalidRank5(self):
    input_tensor = array_ops.ones((8, 8, self.height, self.width, 3))
    with self.assertRaisesRegexp(ValueError, 'rank'):
      layers.masked_conv2d(input_tensor, 32, 3)

  def testSingleConvMaskAdded(self):
    kernel_size = 3
    input_depth, output_depth = 8, 32
    input_tensor = array_ops.ones((8, self.height, self.width, input_depth))
    layers.masked_conv2d(input_tensor, output_depth, kernel_size)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), 1)
    self.assertListEqual(masks[0].get_shape().as_list(),
                         [kernel_size, kernel_size, input_depth, output_depth])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), 1)
    self.assertListEqual(masked_weight[0].get_shape().as_list(),
                         [kernel_size, kernel_size, input_depth, output_depth])

  def testMultipleConvMaskAdded(self):
    number_of_layers = 5

    kernel_size = 3
    base_depth = 4
    depth_step = 7

    input_tensor = array_ops.ones((8, self.height, self.width, base_depth))

    top_layer = input_tensor

    for ix in range(number_of_layers):
      top_layer = layers.masked_conv2d(
          top_layer, base_depth + (ix + 1) * depth_step, kernel_size)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masks[ix].get_shape().as_list(), [
          kernel_size, kernel_size, base_depth + ix * depth_step,
          base_depth + (ix + 1) * depth_step
      ])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
          kernel_size, kernel_size, base_depth + ix * depth_step,
          base_depth + (ix + 1) * depth_step
      ])


class MaskedFullyConnectedLayerTest(test.TestCase):

  def testSingleFCMaskAdded(self):
    input_depth, output_depth = 8, 32
    input_tensor = array_ops.ones((5, input_depth))
    layers.masked_fully_connected(input_tensor, output_depth)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), 1)
    self.assertListEqual(masks[0].get_shape().as_list(),
                         [input_depth, output_depth])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), 1)
    self.assertListEqual(masked_weight[0].get_shape().as_list(),
                         [input_depth, output_depth])

  def testMultipleConvMaskAdded(self):
    number_of_layers = 5

    base_depth = 4
    depth_step = 7

    input_tensor = array_ops.ones((8, base_depth))

    top_layer = input_tensor

    for ix in range(number_of_layers):
      top_layer = layers.masked_fully_connected(
          top_layer, base_depth + (ix + 1) * depth_step)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masks[ix].get_shape().as_list(), [
          base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
      ])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
          base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
      ])


if __name__ == '__main__':
  test.main()
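
Because every masked layer registers its mask in MASK_COLLECTION, measured
sparsity can also be read back at runtime, not just asserted on in tests. A
hedged sketch (assumes a default graph that already contains masked layers):

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python.layers import core_layers
    from tensorflow.python.framework import ops

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for mask in masks:
        # Masks hold 1.0 for kept weights and 0.0 for pruned ones.
        sparsity = 1.0 - sess.run(tf.reduce_mean(mask))
        print(mask.op.name, 'sparsity: %.2f' % sparsity)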
340
tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
Normal file
@ -0,0 +1,340 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN Cells with pruning.

This module implements BasicLSTMCell and LSTMCell with pruning.
Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell as tf_rnn


class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
  """Basic LSTM recurrent network cell with pruning.

  Overrides the call method of tensorflow BasicLSTMCell and injects the
  weight masks.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None):
    """Initialize the basic LSTM cell with pruning.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(MaskedBasicLSTMCell, self).__init__(
        num_units,
        forget_bias=forget_bias,
        state_is_tuple=state_is_tuple,
        activation=activation,
        reuse=reuse,
        name=name)

  def build(self, inputs_shape):
    # Call the build method of the parent class.
    super(MaskedBasicLSTMCell, self).build(inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._mask = self.add_variable(
        name="mask",
        shape=[input_depth + h_depth, 4 * h_depth],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
|
||||||
|
self._threshold = self.add_variable(
|
||||||
|
name="threshold",
|
||||||
|
shape=[],
|
||||||
|
initializer=init_ops.zeros_initializer(),
|
||||||
|
trainable=False,
|
||||||
|
dtype=self.dtype)
|
||||||
|
# Add masked_weights in the weights namescope so as to make it easier
|
||||||
|
# for the quantization library to add quant ops.
|
||||||
|
self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
|
||||||
|
core_layers.MASKED_WEIGHT_NAME)
|
||||||
|
if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
|
||||||
|
ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
|
||||||
|
ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
|
||||||
|
self._masked_kernel)
|
||||||
|
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
|
||||||
|
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Long short-term memory cell (LSTM) with masks for pruning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
|
||||||
|
state: An `LSTMStateTuple` of state tensors, each shaped
|
||||||
|
`[batch_size, self.state_size]`, if `state_is_tuple` has been set to
|
||||||
|
`True`. Otherwise, a `Tensor` shaped
|
||||||
|
`[batch_size, 2 * self.state_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pair containing the new hidden state, and the new state (either a
|
||||||
|
`LSTMStateTuple` or a concatenated state, depending on
|
||||||
|
`state_is_tuple`).
|
||||||
|
"""
|
||||||
|
sigmoid = math_ops.sigmoid
|
||||||
|
one = constant_op.constant(1, dtype=dtypes.int32)
|
||||||
|
# Parameters of gates are concatenated into one multiply for efficiency.
|
||||||
|
if self._state_is_tuple:
|
||||||
|
c, h = state
|
||||||
|
else:
|
||||||
|
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
|
||||||
|
|
||||||
|
gate_inputs = math_ops.matmul(
|
||||||
|
array_ops.concat([inputs, h], 1), self._masked_kernel)
|
||||||
|
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
|
||||||
|
|
||||||
|
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
||||||
|
i, j, f, o = array_ops.split(
|
||||||
|
value=gate_inputs, num_or_size_splits=4, axis=one)
|
||||||
|
|
||||||
|
forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
|
||||||
|
# Note that using `add` and `multiply` instead of `+` and `*` gives a
|
||||||
|
# performance improvement. So using those at the cost of readability.
|
||||||
|
add = math_ops.add
|
||||||
|
multiply = math_ops.multiply
|
||||||
|
new_c = add(
|
||||||
|
multiply(c, sigmoid(add(f, forget_bias_tensor))),
|
||||||
|
multiply(sigmoid(i), self._activation(j)))
|
||||||
|
new_h = multiply(self._activation(new_c), sigmoid(o))
|
||||||
|
|
||||||
|
if self._state_is_tuple:
|
||||||
|
new_state = tf_rnn.LSTMStateTuple(new_c, new_h)
|
||||||
|
else:
|
||||||
|
new_state = array_ops.concat([new_c, new_h], 1)
|
||||||
|
return new_h, new_state
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedLSTMCell(tf_rnn.LSTMCell):
|
||||||
|
"""LSTMCell with pruning.
|
||||||
|
|
||||||
|
Overrides the call method of tensorflow LSTMCell and injects the weight masks.
|
||||||
|
Masks are applied to only the weight matrix of the LSTM and not the
|
||||||
|
projection matrix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_units,
|
||||||
|
use_peepholes=False,
|
||||||
|
cell_clip=None,
|
||||||
|
initializer=None,
|
||||||
|
num_proj=None,
|
||||||
|
proj_clip=None,
|
||||||
|
num_unit_shards=None,
|
||||||
|
num_proj_shards=None,
|
||||||
|
forget_bias=1.0,
|
||||||
|
state_is_tuple=True,
|
||||||
|
activation=None,
|
||||||
|
reuse=None):
|
||||||
|
"""Initialize the parameters for an LSTM cell with masks for pruning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_units: int, The number of units in the LSTM cell
|
||||||
|
use_peepholes: bool, set True to enable diagonal/peephole connections.
|
||||||
|
cell_clip: (optional) A float value, if provided the cell state is clipped
|
||||||
|
by this value prior to the cell output activation.
|
||||||
|
initializer: (optional) The initializer to use for the weight and
|
||||||
|
projection matrices.
|
||||||
|
num_proj: (optional) int, The output dimensionality for the projection
|
||||||
|
matrices. If None, no projection is performed.
|
||||||
|
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
|
||||||
|
provided, then the projected values are clipped elementwise to within
|
||||||
|
`[-proj_clip, proj_clip]`.
|
||||||
|
num_unit_shards: Deprecated, will be removed by Jan. 2017.
|
||||||
|
Use a variable_scope partitioner instead.
|
||||||
|
num_proj_shards: Deprecated, will be removed by Jan. 2017.
|
||||||
|
Use a variable_scope partitioner instead.
|
||||||
|
forget_bias: Biases of the forget gate are initialized by default to 1
|
||||||
|
in order to reduce the scale of forgetting at the beginning of
|
||||||
|
the training. Must set it manually to `0.0` when restoring from
|
||||||
|
CudnnLSTM trained checkpoints.
|
||||||
|
state_is_tuple: If True, accepted and returned states are 2-tuples of
|
||||||
|
the `c_state` and `m_state`. If False, they are concatenated
|
||||||
|
along the column axis. This latter behavior will soon be deprecated.
|
||||||
|
activation: Activation function of the inner states. Default: `tanh`.
|
||||||
|
reuse: (optional) Python boolean describing whether to reuse variables
|
||||||
|
in an existing scope. If not `True`, and the existing scope already has
|
||||||
|
the given variables, an error is raised.
|
||||||
|
|
||||||
|
When restoring from CudnnLSTM-trained checkpoints, must use
|
||||||
|
CudnnCompatibleLSTMCell instead.
|
||||||
|
"""
|
||||||
|
super(MaskedLSTMCell, self).__init__(
|
||||||
|
num_units,
|
||||||
|
use_peepholes=use_peepholes,
|
||||||
|
cell_clip=cell_clip,
|
||||||
|
initializer=initializer,
|
||||||
|
num_proj=num_proj,
|
||||||
|
proj_clip=proj_clip,
|
||||||
|
num_unit_shards=num_unit_shards,
|
||||||
|
num_proj_shards=num_proj_shards,
|
||||||
|
forget_bias=forget_bias,
|
||||||
|
state_is_tuple=state_is_tuple,
|
||||||
|
activation=activation,
|
||||||
|
reuse=reuse)
|
||||||
|
|
||||||
|
def build(self, inputs_shape):
|
||||||
|
# Call the build method of the parent class.
|
||||||
|
super(MaskedLSTMCell, self).build(inputs_shape)
|
||||||
|
|
||||||
|
input_depth = inputs_shape[1].value
|
||||||
|
h_depth = self._num_units
|
||||||
|
self._mask = self.add_variable(
|
||||||
|
name="mask",
|
||||||
|
shape=[input_depth + h_depth, 4 * h_depth],
|
||||||
|
initializer=init_ops.ones_initializer(),
|
||||||
|
trainable=False,
|
||||||
|
dtype=self.dtype)
|
||||||
|
self._threshold = self.add_variable(
|
||||||
|
name="threshold",
|
||||||
|
shape=[],
|
||||||
|
initializer=init_ops.zeros_initializer(),
|
||||||
|
trainable=False,
|
||||||
|
dtype=self.dtype)
|
||||||
|
# Add masked_weights in the weights namescope so as to make it easier
|
||||||
|
# for the quantization library to add quant ops.
|
||||||
|
self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
|
||||||
|
core_layers.MASKED_WEIGHT_NAME)
|
||||||
|
if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
|
||||||
|
ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
|
||||||
|
ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
|
||||||
|
self._masked_kernel)
|
||||||
|
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
|
||||||
|
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Run one step of LSTM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: input Tensor, 2D, `[batch, num_units].
|
||||||
|
state: if `state_is_tuple` is False, this must be a state Tensor,
|
||||||
|
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
|
||||||
|
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
|
||||||
|
`m_state`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing:
|
||||||
|
|
||||||
|
- A `2-D, [batch, output_dim]`, Tensor representing the output of the
|
||||||
|
LSTM after reading `inputs` when previous state was `state`.
|
||||||
|
Here output_dim is:
|
||||||
|
num_proj if num_proj was set,
|
||||||
|
num_units otherwise.
|
||||||
|
- Tensor(s) representing the new state of LSTM after reading `inputs` when
|
||||||
|
the previous state was `state`. Same type and shape(s) as `state`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input size cannot be inferred from inputs via
|
||||||
|
static shape inference.
|
||||||
|
"""
|
||||||
|
num_proj = self._num_units if self._num_proj is None else self._num_proj
|
||||||
|
sigmoid = math_ops.sigmoid
|
||||||
|
|
||||||
|
if self._state_is_tuple:
|
||||||
|
(c_prev, m_prev) = state
|
||||||
|
else:
|
||||||
|
c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
|
||||||
|
m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
|
||||||
|
|
||||||
|
input_size = inputs.get_shape().with_rank(2)[1]
|
||||||
|
if input_size.value is None:
|
||||||
|
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
|
||||||
|
|
||||||
|
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
||||||
|
lstm_matrix = math_ops.matmul(
|
||||||
|
array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
|
||||||
|
lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)
|
||||||
|
|
||||||
|
i, j, f, o = array_ops.split(
|
||||||
|
value=lstm_matrix, num_or_size_splits=4, axis=1)
|
||||||
|
# Diagonal connections
|
||||||
|
if self._use_peepholes:
|
||||||
|
c = (
|
||||||
|
sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
|
||||||
|
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
|
||||||
|
else:
|
||||||
|
c = (
|
||||||
|
sigmoid(f + self._forget_bias) * c_prev +
|
||||||
|
sigmoid(i) * self._activation(j))
|
||||||
|
|
||||||
|
if self._cell_clip is not None:
|
||||||
|
# pylint: disable=invalid-unary-operand-type
|
||||||
|
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
|
||||||
|
# pylint: enable=invalid-unary-operand-type
|
||||||
|
if self._use_peepholes:
|
||||||
|
m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
|
||||||
|
else:
|
||||||
|
m = sigmoid(o) * self._activation(c)
|
||||||
|
|
||||||
|
if self._num_proj is not None:
|
||||||
|
m = math_ops.matmul(m, self._proj_kernel)
|
||||||
|
|
||||||
|
if self._proj_clip is not None:
|
||||||
|
# pylint: disable=invalid-unary-operand-type
|
||||||
|
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
|
||||||
|
# pylint: enable=invalid-unary-operand-type
|
||||||
|
|
||||||
|
new_state = (
|
||||||
|
tf_rnn.LSTMStateTuple(c, m)
|
||||||
|
if self._state_is_tuple else array_ops.concat([c, m], 1))
|
||||||
|
return m, new_state
|
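A minimal usage sketch for the cells above (editorial addition, not part of the commit; graph-mode TF 1.x assumed, all sizes made up). Building the cell once registers exactly one mask/threshold pair in the pruning collections, which the Pruning class further below then updates:

import tensorflow as tf

from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import rnn_cells

batch_size, dim = 4, 16
inputs = tf.random_normal([batch_size, dim])
cell = rnn_cells.MaskedBasicLSTMCell(dim)
state = cell.zero_state(batch_size, tf.float32)
output, new_state = cell(inputs, state)

# build() added one mask over the concatenated input-to-gates and
# hidden-to-gates kernel, hence the [2*dim, 4*dim] shape.
assert len(pruning.get_masks()) == 1
assert pruning.get_masks()[0].shape == (2 * dim, 4 * dim)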
@ -0,0 +1,85 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for creating different numbers of masks in rnn_cells."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import rnn_cells
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell as tf_rnn_cells
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class RnnCellsTest(test.TestCase):

  def setUp(self):
    super(RnnCellsTest, self).setUp()
    self.batch_size = 8
    self.dim = 10

  def testMaskedBasicLSTMCell(self):
    expected_num_masks = 1
    expected_num_rows = 2 * self.dim
    expected_num_cols = 4 * self.dim
    with self.test_session():
      inputs = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      c = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      h = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      state = tf_rnn_cells.LSTMStateTuple(c, h)
      lstm_cell = rnn_cells.MaskedBasicLSTMCell(self.dim)
      lstm_cell(inputs, state)
      self.assertEqual(len(pruning.get_masks()), expected_num_masks)
      self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
      self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
      self.assertEqual(len(pruning.get_weights()), expected_num_masks)

      for mask in pruning.get_masks():
        self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
      for weight in pruning.get_weights():
        self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))

  def testMaskedLSTMCell(self):
    expected_num_masks = 1
    expected_num_rows = 2 * self.dim
    expected_num_cols = 4 * self.dim
    with self.test_session():
      inputs = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      c = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      h = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      state = tf_rnn_cells.LSTMStateTuple(c, h)
      lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
      lstm_cell(inputs, state)
      self.assertEqual(len(pruning.get_masks()), expected_num_masks)
      self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
      self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
      self.assertEqual(len(pruning.get_weights()), expected_num_masks)

      for mask in pruning.get_masks():
        self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
      for weight in pruning.get_weights():
        self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))


if __name__ == '__main__':
  test.main()
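# Editorial note (not part of the commit): the expected shapes above follow
# directly from build() in rnn_cells.py. The kernel multiplies the
# concatenation [inputs, h] (dim + dim rows) and emits the four gates
# (4 * dim columns), so the single mask is (2*dim) x (4*dim):
#
#   dim = 10
#   expected_num_rows = 2 * dim  # input_depth + h_depth -> 20
#   expected_num_cols = 4 * dim  # i, j, f, o gates stacked -> 40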
188
tensorflow/contrib/model_pruning/python/learning.py
Normal file
@ -0,0 +1,188 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper around tf-slim's training code contrib/slim/python/slim/learning.py
to support training of pruned models.

*******************************************************************
* A simple working training script with support for model pruning *
*******************************************************************

  # Load data and create the model:
  images, labels = LoadData(...)
  predictions = MyModel(images)

  # Define the loss:
  slim.losses.log_loss(predictions, labels)
  total_loss = slim.losses.get_total_loss()

  # Define the optimizer:
  optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)

  # Create the train_op
  train_op = slim.learning.create_train_op(total_loss, optimizer)

  # Set up sparsity
  sparsity = pruning.setup_gradual_sparsity(self.global_step)

  # Create mask update op
  mask_update_op = pruning.add_mask_update_op(sparsity)

  # Run training.
  learning.train(train_op,
                 my_log_dir,
                 mask_update_op)

See contrib/slim/python/slim/learning.py for additional examples.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import slim as _slim

_USE_DEFAULT = 0
train_step = _slim.learning.train_step


def train(train_op,
          logdir,
          mask_update_op,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          trace_every_n_steps=None):
  """Wrapper around tf-slim's train function.

  Runs a training loop using a TensorFlow supervisor.
  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    mask_update_op: Operation that upon execution updates the weight masks
      and thresholds.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a
      dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`.
      By default, two `Boolean`, scalar ops called "should_stop" and
      "should_log" are provided.
    log_every_n_steps: The frequency, in terms of global steps, at which the
      loss and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the
      primary replica during replica training.
    global_step: The `Tensor` representing the global step. If left as
      `None`, then slim.variables.get_or_create_global_step() is used.
    number_of_steps: The max number of gradient steps to take during
      training, as measured by 'global_step': training will stop if
      global_step is greater than 'number_of_steps'. If the value is left as
      None, training proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling
      `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called.
      The callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use. Can be `None` to indicate that no
      summaries should be written. If unset, we create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning.
      Note that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list
      of them. If the argument is supplied, gradient updates will be
      synchronous. If left as `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no
      trace information will be produced or saved.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """

  def train_step_with_pruning_fn(sess, train_op, global_step,
                                 train_step_kwargs):
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    sess.run(mask_update_op)
    return total_loss, should_stop

  total_loss, _ = _slim.learning.train(
      train_op,
      logdir,
      train_step_fn=train_step_with_pruning_fn,
      train_step_kwargs=train_step_kwargs,
      log_every_n_steps=log_every_n_steps,
      graph=graph,
      master=master,
      is_chief=is_chief,
      global_step=global_step,
      number_of_steps=number_of_steps,
      init_op=init_op,
      init_feed_dict=init_feed_dict,
      local_init_op=local_init_op,
      init_fn=init_fn,
      ready_op=ready_op,
      summary_op=summary_op,
      save_summaries_secs=save_summaries_secs,
      summary_writer=summary_writer,
      startup_delay_steps=startup_delay_steps,
      saver=saver,
      save_interval_secs=save_interval_secs,
      sync_optimizer=sync_optimizer,
      session_config=session_config,
      trace_every_n_steps=trace_every_n_steps)

  return total_loss
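An end-to-end sketch of how this wrapper is meant to be driven (editorial, not part of the commit). Note the docstring's `setup_gradual_sparsity` and `add_mask_update_op` helpers are illustrative names that do not appear in this commit's pruning.py; the op it actually exposes is `Pruning.conditional_mask_update_op()`, used below. Shapes, paths and step counts are made up:

import tensorflow as tf

from tensorflow.contrib.model_pruning.python import learning
from tensorflow.contrib.model_pruning.python import pruning

# Toy regression problem.
images = tf.random_normal([32, 10])
labels = tf.random_normal([32, 1])
weights = tf.get_variable('weights', shape=[10, 1])
# apply_mask() registers the mask/threshold consumed by Pruning below.
predictions = tf.matmul(images, pruning.apply_mask(weights))
loss = tf.losses.mean_squared_error(labels, predictions)

global_step = tf.train.get_or_create_global_step()
train_op = tf.contrib.slim.learning.create_train_op(
    loss, tf.train.GradientDescentOptimizer(0.01), global_step=global_step)

p = pruning.Pruning(global_step=global_step)
mask_update_op = p.conditional_mask_update_op()

# train() runs one slim train step, then runs mask_update_op after it.
learning.train(train_op, '/tmp/pruning_demo', mask_update_op,
               number_of_steps=100)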
585
tensorflow/contrib/model_pruning/python/pruning.py
Normal file
@ -0,0 +1,585 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to add support for magnitude-based model pruning.

  # Adds variables and ops to the graph to enable
  # elementwise masking of weights
  apply_mask(weights)

  # Returns a list containing the sparsity of each of the weight tensors
  get_weight_sparsity()

  # Returns a list of all the masked weight tensorflow variables
  get_masked_weights()

  # Returns a list of all the mask tensorflow variables
  get_masks()

  # Returns a list of all the thresholds
  get_thresholds()

  # Returns a list of all the weight tensors that have been masked
  get_weights()

  The Pruning class uses a proto (defined in pruning.proto) to set up the
  parameters for a pruning specification. Here's a typical usage:

  # Initialize a pruning spec from a proto
  pruning_spec = '/tmp/pruning.pb'
  p = Pruning(pruning_spec)

  # Add mask update ops to the graph
  mask_update_op = p.conditional_mask_update_op()

  # Add the summaries
  p.add_pruning_summaries()

  # Run the op
  session.run(mask_update_op)

  # A Pruning object also accepts an externally defined sparsity variable:
  sparsity = tf.Variable(0.5, name = "ConstantSparsity")
  pruning_spec = '/tmp/pruning.pb'
  p = Pruning(pruning_spec, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util

_MASK_COLLECTION = core.MASK_COLLECTION
_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME


def _weight_mask_variable(var, scope):
  """Create a mask for the weights.

  This function adds a variable 'mask' to the graph.

  Args:
    var: the weight variable that needs to be masked
    scope: The variable scope of the variable var

  Returns:
    the mask variable of the same size and shape as var, initialized to all
    1s.
  """
  with variable_scope.variable_scope(scope):
    mask = variable_scope.get_variable(
        'mask',
        var.get_shape(),
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=var.dtype)
  return mask


def _weight_threshold_variable(var, scope):
  """Create a scalar threshold for the weights.

  This function adds a variable 'threshold' to the graph.

  Args:
    var: The weight variable that needs to be masked
    scope: The variable scope of the variable var

  Returns:
    a scalar threshold variable initialized to 0.
  """
  with variable_scope.variable_scope(scope):
    threshold = variable_scope.get_variable(
        'threshold', [],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=var.dtype)
  return threshold


def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
  """Return histogram of values.

  Given the tensor `values`, this operation returns a rank 1 histogram
  counting the number of entries in `values` that fell into every bin. The
  bins are equal width and determined by the arguments `value_range` and
  `nbins`.

  Args:
    values: Numeric `Tensor`.
    value_range: Shape [2] `Tensor` of same `dtype` as `values`.
      values <= value_range[0] will be mapped to hist[0],
      values >= value_range[1] will be mapped to hist[-1].
    nbins: Scalar `int32 Tensor`. Number of histogram bins.
    dtype: dtype for returned histogram.
    name: A name for this operation (defaults to 'histogram').

  Returns:
    A 1-D `Tensor` holding histogram of values.
  """
  with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
    values = ops.convert_to_tensor(values, name='values')
    values = gen_array_ops.reshape(values, [-1])
    value_range = ops.convert_to_tensor(value_range, name='value_range')
    nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins')
    nbins_float = math_ops.cast(nbins, values.dtype)

    # Map tensor values that fall within value_range to [0, 1].
    scaled_values = math_ops.truediv(
        values - value_range[0],
        value_range[1] - value_range[0],
        name='scaled_values')

    # Map tensor values within the open interval value_range to
    # {0,.., nbins-1}; values outside the open interval will be zero or
    # less, or nbins or more.
    indices = math_ops.floor(nbins_float * scaled_values, name='indices')

    # Clip edge cases (e.g. value = value_range[1]) or "outliers."
    indices = math_ops.cast(
        clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32)

    return math_ops.unsorted_segment_sum(
        array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
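# Editorial sketch (not part of the commit): what _histogram computes on a
# concrete input. With values [0, 1, 2, 3] and value_range [0, 4] over
# nbins=4 bins of width 1, each value lands in its own bin:
#
#   scaled_values = [0.0, 0.25, 0.5, 0.75]
#   indices       = floor(4 * scaled_values) = [0, 1, 2, 3]
#   histogram     = [1, 1, 1, 1]
#
# Values at or above value_range[1] are clipped into the last bin.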

def _determine_partitioned_axis(partitioned_variable):
  partitioned_axis = 0
  concatenated_variable_shape = partitioned_variable.get_shape()
  for partition in partitioned_variable:
    partition_shape = partition.get_shape()
    maybe_partitioned_axis = np.less(partition_shape,
                                     concatenated_variable_shape)
    # Sanity check: make sure the number of partitioned axes == 1
    if np.count_nonzero(maybe_partitioned_axis) != 1:
      raise ValueError('Number of partitioned axes %s not equal to 1' %
                       np.count_nonzero(maybe_partitioned_axis))
    partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
  return partitioned_axis


def _variable_assign(var, new_value):
  return state_ops.assign(var, new_value, name=var.op.name + '_assign')


def _partitioned_variable_assign(partitioned_var, new_value):
  """Assign op for partitioned variables.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable
    slices
  """
  # Determine which axis was used to partition the variable. Currently
  # tensorflow allows partitioning a variable only along one axis.
  axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis(
      partitioned_var)

  partition_sizes = np.array(
      [partition.get_shape()[axis] for partition in partitioned_var])
  new_partitioned_values = array_ops.split(
      new_value,
      ops.convert_to_tensor(partition_sizes, dtype=np.int32),
      axis=axis)
  op_list = []
  for partition in partitioned_var:
    op_list.append(
        _variable_assign(partition, new_partitioned_values[len(op_list)]))
  return control_flow_ops.group(
      *op_list, name=partitioned_var.name + '_group_assign')


def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to ""
  Returns:
    Tensor representing masked_weights
  """

  mask = _weight_mask_variable(x, scope)
  threshold = _weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Make sure the mask for a given variable is not added multiple times to
  # the collection. This is particularly important when applying masks to
  # RNN weight variables.
  if mask not in ops.get_collection_ref(_MASK_COLLECTION):
    ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
    ops.add_to_collection(_MASK_COLLECTION, mask)
    ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
    ops.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked_weights
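# Editorial sketch (not part of the commit): apply_mask() is a no-op on the
# forward value until a Pruning object rewrites the mask. Assuming graph-mode
# TF 1.x:
#
#   weights = tf.get_variable('w', initializer=tf.linspace(1.0, 100.0, 100))
#   masked = pruning.apply_mask(weights)
#   # masked == weights initially, since the mask is all ones; after running
#   # Pruning(...).mask_update_op() with sparsity 0.5, roughly half of the
#   # entries of `masked` are zeroed (see pruning_test.py below).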

def get_masked_weights():
  return ops.get_collection(_MASKED_WEIGHT_COLLECTION)


def get_masks():
  return ops.get_collection(_MASK_COLLECTION)


def get_thresholds():
  return ops.get_collection(_THRESHOLD_COLLECTION)


def get_weights():
  return ops.get_collection(_WEIGHT_COLLECTION)


def get_weight_sparsity():
  """Get sparsity of the weights.

  Args:
    None

  Returns:
    A list containing the sparsity of each of the weight tensors
  """
  masks = get_masks()
  return [nn_impl.zero_fraction(mask) for mask in masks]


def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops
      under a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1, implying
      that pruning continues until training stops
    do_not_prune: list of strings
      list of layers that are not pruned
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      how often the masks should be updated (in number of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 gives linearly varying sparsity between initial and final;
      exponent > 1 varies more slowly towards the end than the beginning

    We use the following sparsity function:

    num_steps = (sparsity_function_end_step -
                 sparsity_function_begin_step)/pruning_frequency
    sparsity(step) = (initial_sparsity - target_sparsity)*
                     [1-step/(num_steps -1)]**exponent + target_sparsity

  Args:
    None

  Returns:
    tf.HParams object initialized to default values

  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      do_not_prune=[''],
      threshold_decay=0.9,
      pruning_frequency=10,
      nbins=255,
      initial_sparsity=0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3)
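# Editorial worked example (not part of the commit) of the gradual sparsity
# schedule documented above, with the default hyperparameters
# (initial_sparsity=0, target_sparsity=0.5, begin=0, end=100, exponent=3):
#
#   p        = clip((step - 0) / (100 - 0), 0.0, 1.0)
#   sparsity = (0 - 0.5) * (1 - p)**3 + 0.5
#
#   step=0   -> sparsity = -0.5 * 1.000 + 0.5 = 0.0
#   step=50  -> sparsity = -0.5 * 0.125 + 0.5 = 0.4375
#   step=100 -> sparsity = -0.5 * 0.000 + 0.5 = 0.5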

class Pruning(object):

  def __init__(self,
               spec=None,
               global_step=None,
               sparsity=None,
               partitioner=None):
    """Set up the specification for model pruning.

    If a spec is provided, the sparsity is set up based on the
    sparsity_function in the spec. The effect of sparsity_function is
    overridden if a sparsity variable is passed to the constructor. This
    enables setting up arbitrary sparsity profiles externally and passing
    them to this pruning function.

    Args:
      spec: Pruning spec as defined in pruning.proto
      global_step: A tensorflow variable that is used while setting up the
        sparsity function
      sparsity: A tensorflow scalar variable storing the sparsity
      partitioner: The tensorflow partitioner function used to distribute
        parameters across shards
    """
    # Pruning specification
    self._spec = spec if spec else get_pruning_hparams()

    # A tensorflow variable that tracks the sparsity function.
    # If not provided as input, the graph must already contain the
    # global_step variable before calling this constructor.
    self._global_step = self._setup_global_step(global_step)

    # Stores the tensorflow sparsity variable.
    # Built using self._setup_sparsity() or provided externally
    self._sparsity = sparsity if sparsity else self._setup_sparsity()

    # Stores the partitioner function used to partition variables across
    # tasks.
    self._partitioner = partitioner

    # List of tensorflow assignment ops for new masks and thresholds
    self._assign_ops = []

    # Tensorflow variable keeping track of the last global step when the
    # masks were updated
    self._last_update_step = self._setup_last_update_step()

  def _setup_global_step(self, global_step):
    graph_global_step = global_step
    if graph_global_step is None:
      graph_global_step = training_util.get_global_step()

    return math_ops.cast(graph_global_step, np.int32)

  def _setup_sparsity(self):
    begin_step = self._spec.sparsity_function_begin_step
    end_step = self._spec.sparsity_function_end_step
    initial_sparsity = self._spec.initial_sparsity
    target_sparsity = self._spec.target_sparsity
    exponent = self._spec.sparsity_function_exponent

    if begin_step >= end_step:
      raise ValueError(
          'Pruning must begin before it can end. begin_step=%d, end_step=%d' %
          (begin_step, end_step))

    with ops.name_scope(self._spec.name):
      p = math_ops.minimum(1.0,
                           math_ops.maximum(
                               0.0,
                               math_ops.div(
                                   math_ops.cast(
                                       self._global_step - begin_step,
                                       np.float32),
                                   end_step - begin_step)))
      sparsity = math_ops.add(
          math_ops.multiply(initial_sparsity - target_sparsity,
                            math_ops.pow(1 - p, exponent)),
          target_sparsity,
          name='sparsity')

    return sparsity

  def _setup_last_update_step(self):
    with variable_scope.variable_scope(self._spec.name) as scope:
      try:
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', [],
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            dtype=np.int32)
      except ValueError:
        scope.reuse_variables()
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', dtype=np.int32)
    return last_update_step

  def _exists_in_do_not_prune_list(self, tensor_name):
    do_not_prune_list = self._spec.do_not_prune
    if not do_not_prune_list[0]:
      return False
    for layer_name in do_not_prune_list:
      if tensor_name.find(layer_name) != -1:
        return True

    return False

  def _update_mask(self, weights, threshold):
    """Updates the mask for a given weight tensor.

    This function first computes the cdf of the weight tensor, and estimates
    the threshold value such that 'desired_sparsity' fraction of weights
    have magnitude less than the threshold.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        desired_sparsity
      new_mask: An n-D tensor containing 0 or 1 to indicate which of the
        values in weights fall below the threshold

    Raises:
      ValueError: if sparsity is not defined
    """
    if self._sparsity is None:
      raise ValueError('Sparsity variable undefined')

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(weights)
      max_value = math_ops.reduce_max(abs_weights)
      histogram = _histogram(
          abs_weights, [0.0, max_value],
          nbins=self._spec.nbins,
          dtype=np.float32)

      cdf = math_ops.cumsum(histogram)
      norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram))
      current_threshold = math_ops.multiply(
          math_ops.div(
              math_ops.reduce_sum(
                  math_ops.cast(
                      math_ops.less(norm_cdf, self._sparsity), np.float32)),
              float(self._spec.nbins)), max_value)

      smoothed_threshold = math_ops.add_n([
          math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
          math_ops.multiply(threshold, self._spec.threshold_decay)
      ])
      new_mask = math_ops.cast(
          math_ops.greater(abs_weights, smoothed_threshold), np.float32)
    return smoothed_threshold, new_mask
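# Editorial sketch (not part of the commit): how _update_mask() picks a
# threshold. With weights linspace(1, 100, 100), sparsity 0.5 and
# threshold_decay 0:
#
#   abs_weights fill the histogram bins up to max_value = 100;
#   norm_cdf(b) is the fraction of weights with magnitude below bin b;
#   the number of bins where norm_cdf < 0.5 is roughly nbins/2, so
#   current_threshold ~= (nbins/2 / nbins) * 100 = 50;
#   new_mask keeps entries with |w| > 50, i.e. about half of them
#   (pruning_test.py below checks that 51 weights survive).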

  def _get_mask_assign_ops(self):
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))

    for index, mask in enumerate(masks):
      threshold = thresholds[index]
      weight = weights[index] if self._partitioner is None else weights[
          index].as_tensor()

      if self._spec.do_not_prune:
        if self._exists_in_do_not_prune_list(mask.name):
          continue

      new_threshold, new_mask = self._update_mask(weight, threshold)
      self._assign_ops.append(_variable_assign(threshold, new_threshold))
      self._assign_ops.append(
          _variable_assign(mask, new_mask) if self._partitioner is None else
          _partitioned_variable_assign(mask, new_mask))

  def mask_update_op(self):
    with ops.name_scope(self._spec.name):
      if not self._assign_ops:
        self._get_mask_assign_ops()
      with ops.control_dependencies([
          state_ops.assign(
              self._last_update_step,
              self._global_step,
              name='last_mask_update_step_assign')
      ]):
        with ops.control_dependencies(self._assign_ops):
          logging.info('Updating masks.')
          return control_flow_ops.no_op('mask_update')

  def conditional_mask_update_op(self):

    def maybe_update_masks():
      with ops.name_scope(self._spec.name):
        is_step_within_pruning_range = math_ops.logical_and(
            math_ops.greater_equal(self._global_step,
                                   self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            math_ops.logical_or(
                math_ops.less_equal(self._global_step,
                                    self._spec.end_pruning_step),
                math_ops.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = math_ops.less_equal(
            math_ops.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return math_ops.logical_and(is_step_within_pruning_range,
                                    is_pruning_step)

    def mask_update_op():
      return self.mask_update_op()

    def no_update_op():
      return control_flow_ops.no_op()

    return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
                                 no_update_op)

  def add_pruning_summaries(self):
    """Adds summaries for this pruning spec.

    Args: none

    Returns: none
    """
    with ops.name_scope(self._spec.name + '_summaries'):
      summary.scalar('sparsity', self._sparsity)
      summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
      for index, mask in enumerate(masks):
        if not self._exists_in_do_not_prune_list(mask.name):
          summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
          summary.scalar(thresholds[index].op.name + '/threshold',
                         thresholds[index])

  def print_hparams(self):
    logging.info(self._spec.to_json())
162
tensorflow/contrib/model_pruning/python/pruning_test.py
Normal file
@ -0,0 +1,162 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in the pruning library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util


class PruningHParamsTest(test.TestCase):
  PARAM_LIST = [
      "name=test", "threshold_decay=0.9", "pruning_frequency=10",
      "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
      "target_sparsity=0.9"
  ]
  TEST_HPARAMS = ",".join(PARAM_LIST)

  def setUp(self):
    super(PruningHParamsTest, self).setUp()
    # Add global step variable to the graph
    self.global_step = training_util.get_or_create_global_step()
    # Add sparsity
    self.sparsity = variables.Variable(0.5, name="sparsity")
    # Parse hparams
    self.pruning_hparams = pruning.get_pruning_hparams().parse(
        self.TEST_HPARAMS)

  def testInit(self):
    p = pruning.Pruning(self.pruning_hparams)
    self.assertEqual(p._spec.name, "test")
    self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
    self.assertEqual(p._spec.pruning_frequency, 10)
    self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"])
    self.assertEqual(p._spec.sparsity_function_end_step, 100)
    self.assertAlmostEqual(p._spec.target_sparsity, 0.9)

  def testInitWithExternalSparsity(self):
    with self.test_session():
      p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      sparsity = p._sparsity.eval()
      self.assertAlmostEqual(sparsity, 0.5)

  def testInitWithVariableReuse(self):
    with self.test_session():
      p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
      p_copy = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      sparsity = p._sparsity.eval()
      self.assertAlmostEqual(sparsity, 0.5)
      self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())


class PruningTest(test.TestCase):

  def setUp(self):
    super(PruningTest, self).setUp()
    self.global_step = training_util.get_or_create_global_step()

  def testCreateMask2D(self):
    width = 10
    height = 20
    with self.test_session():
      weights = variables.Variable(
          random_ops.random_normal([width, height], stddev=1), name="weights")
      masked_weights = pruning.apply_mask(weights,
                                          variable_scope.get_variable_scope())
      variables.global_variables_initializer().run()
      weights_val = weights.eval()
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(weights_val, masked_weights_val)

  def testUpdateSingleMask(self):
    with self.test_session() as session:
      weights = variables.Variable(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      masked_weights = pruning.apply_mask(weights)
      sparsity = variables.Variable(0.5, name="sparsity")
      p = pruning.Pruning(sparsity=sparsity)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      variables.global_variables_initializer().run()
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)

  def testPartitionedVariableMasking(self):
    partitioner = partitioned_variables.variable_axis_size_partitioner(40)
    with self.test_session() as session:
      with variable_scope.variable_scope("", partitioner=partitioner):
        sparsity = variables.Variable(0.5, name="Sparsity")
        weights = variable_scope.get_variable(
            "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
        masked_weights = pruning.apply_mask(
            weights, scope=variable_scope.get_variable_scope())
      p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      variables.global_variables_initializer().run()
      masked_weights_val = masked_weights.eval()
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)

  def testConditionalMaskUpdate(self):
    param_list = [
        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    weights = variables.Variable(
        math_ops.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = variables.Variable(0.00, name="sparsity")
    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.conditional_mask_update_op()
    sparsity_val = math_ops.linspace(0.0, 0.9, 10)
    increment_global_step = state_ops.assign_add(self.global_step, 1)
    non_zero_count = []
    with self.test_session() as session:
      variables.global_variables_initializer().run()
      for i in range(10):
        session.run(state_ops.assign(sparsity, sparsity_val[i]))
        session.run(mask_update_op)
        session.run(increment_global_step)
        non_zero_count.append(np.count_nonzero(masked_weights.eval()))
      # Weights are pruned at global steps 0, 2, 4, and 6.
      expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
      self.assertAllEqual(expected_non_zero_count, non_zero_count)


if __name__ == "__main__":
  test.main()