Extend block sparsity support for TPUs

PiperOrigin-RevId: 195685740
This commit is contained in:
A. Unique TensorFlower 2018-05-07 10:49:26 -07:00 committed by TensorFlower Gardener
parent b2888c66e6
commit 9ba26ca0d5
3 changed files with 116 additions and 27 deletions

View File

@ -396,14 +396,19 @@ class Pruning(object):
self._block_pooling_function)
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(
array_ops.reshape(weights, [
1,
squeezed_weights.get_shape()[0],
squeezed_weights.get_shape()[1], 1
]))
abs_weights = math_ops.abs(squeezed_weights)
pool_window = [self._block_dim[0], self._block_dim[1]]
pooled_weights = nn_ops.pool(
pool_fn = pruning_utils.factorized_pool
if not self._spec.use_tpu:
pool_fn = nn_ops.pool
abs_weights = array_ops.reshape(
abs_weights,
[1, abs_weights.get_shape()[0],
abs_weights.get_shape()[1], 1])
pooled_weights = pool_fn(
abs_weights,
window_shape=pool_window,
pooling_type=self._block_pooling_function,
@ -411,19 +416,18 @@ class Pruning(object):
padding='SAME',
name=weights.op.name + '_pooled')
if pooled_weights.get_shape().ndims != 2:
pooled_weights = array_ops.squeeze(pooled_weights)
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
reshaped_mask = array_ops.reshape(
new_mask,
[pooled_weights.get_shape()[1],
pooled_weights.get_shape()[2]])
updated_mask = pruning_utils.kronecker_product(
reshaped_mask, array_ops.ones(self._block_dim))
new_mask, array_ops.ones(self._block_dim))
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
squeezed_weights.get_shape()[1]])
return smoothed_threshold, array_ops.reshape(sliced_mask,
array_ops.shape(weights))

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs):
return math_ops.div(cdf, math_ops.reduce_max(cdf))
def factorized_pool(input_tensor,
                    window_shape,
                    pooling_type,
                    strides,
                    padding,
                    name=None):
  """Performs m x n pooling through a combination of 1xm and 1xn pooling.

  The 2-D pooling is factorized into two 1-D pooling passes: first along the
  height of the matrix, then (after swapping the last two axes) along its
  width.

  Args:
    input_tensor: Input tensor. Must be rank 2
    window_shape: Pooling window shape, given as [height, width]
    pooling_type: Either 'MAX' or 'AVG'
    strides: The stride of the pooling window, given as [height, width]
    padding: 'SAME' or 'VALID'.
    name: Name of the op

  Returns:
    A rank 2 tensor containing the pooled output

  Raises:
    ValueError: if the input tensor is not rank 2
  """
  if input_tensor.get_shape().ndims != 2:
    raise ValueError('factorized_pool() accepts tensors of rank 2 only')

  [height, width] = input_tensor.get_shape()
  with ops.name_scope(name, 'factorized_pool'):
    # Present the matrix as a 4-D tensor with two leading singleton axes so
    # that nn_ops.pool can be applied with a [1, k] window.
    input_tensor_aligned = array_ops.reshape(
        input_tensor, [1, 1, height, width],
        name=input_tensor.op.name + '_aligned')

    # First pass: pool along the height axis.
    height_pooling = nn_ops.pool(
        input_tensor_aligned,
        window_shape=[1, window_shape[0]],
        pooling_type=pooling_type,
        strides=[1, strides[0]],
        padding=padding)
    # Swap the last two axes so the second pass pools along the width axis.
    swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])

    # Second pass: pool along the (now swapped-in) width axis.
    width_pooling = nn_ops.pool(
        swap_height_width,
        window_shape=[1, window_shape[1]],
        pooling_type=pooling_type,
        strides=[1, strides[1]],
        padding=padding)

  # Restore the original axis order and drop only the two leading singleton
  # axes added above. An unrestricted squeeze would also remove a pooled
  # height or width of size 1 and break the rank 2 contract.
  return array_ops.squeeze(
      array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1])
def determine_partitioned_axis(partitioned_variable):
partitioned_axis = 0
concatenated_variable_shape = partitioned_variable.get_shape()

View File

@ -22,8 +22,10 @@ import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -31,6 +33,30 @@ from tensorflow.python.platform import test
class PruningUtilsTest(test.TestCase):
def _compare_cdf(self, values):
  """Asserts that compute_cdf and compute_cdf_from_histogram agree.

  Both CDFs are computed on the absolute value of `values` over the range
  [0.0, max(|values|)], and the evaluated results must be exactly equal.

  Args:
    values: Tensor whose absolute values are fed to both CDF routines.
  """
  magnitudes = math_ops.abs(values)
  value_range = [0.0, math_ops.reduce_max(magnitudes)]

  with self.test_session() as sess:
    # Variables are initialized before the CDF ops are built, as before.
    sess.run(variables.global_variables_initializer())
    histogram_cdf = pruning_utils.compute_cdf_from_histogram(
        magnitudes, value_range, nbins=pruning_utils._NBINS)
    direct_cdf = pruning_utils.compute_cdf(magnitudes, value_range)
    self.assertAllEqual(sess.run(direct_cdf), sess.run(histogram_cdf))
def _compare_pooling_methods(self, weights, pooling_kwargs):
  """Asserts that factorized_pool matches nn_ops.pool on the same weights.

  The reference result reshapes `weights` to [1, rows, cols, 1], pools it
  with nn_ops.pool and squeezes the output; the candidate result comes from
  pruning_utils.factorized_pool applied to `weights` directly.

  Args:
    weights: Tensor indexed as a matrix (its first two dimensions are used).
    pooling_kwargs: Dict of keyword arguments forwarded to both pooling
      functions.
  """
  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())

    rows = weights.get_shape()[0]
    cols = weights.get_shape()[1]
    weights_4d = array_ops.reshape(weights, [1, rows, cols, 1])
    reference = array_ops.squeeze(nn_ops.pool(weights_4d, **pooling_kwargs))

    factorized = pruning_utils.factorized_pool(weights, **pooling_kwargs)

    self.assertAllClose(sess.run(reference), sess.run(factorized))
def testHistogram(self):
width = 10
height = 10
@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase):
self.assertAllEqual(len(norm_cdf_val), nbins)
self.assertAllEqual(expected_cdf, norm_cdf_val)
def _compare_cdf(self, values):
"""Evaluates both CDF implementations on the absolute value of `values`.

Args:
values: Tensor whose absolute values are fed to both CDF routines.

Returns:
A tuple (cdf, cdf_from_histogram) of the evaluated results of
pruning_utils.compute_cdf and pruning_utils.compute_cdf_from_histogram.
"""
abs_values = math_ops.abs(values)
max_value = math_ops.reduce_max(abs_values)
with self.test_session():
variables.global_variables_initializer().run()
# Both CDFs are computed over the same range, [0.0, max(|values|)].
cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
return cdf.eval(), cdf_from_histogram.eval()
def testCDFEquivalence2D(self):
width = 100
height = 100
weights = variable_scope.get_variable("weights", shape=[width, height])
cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
self.assertAllEqual(cdf_val, cdf_from_histogram_val)
self._compare_cdf(weights)
def testCDFEquivalence4D(self):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
self.assertAllEqual(cdf_val, cdf_from_histogram_val)
self._compare_cdf(weights)
def testFactorizedAvgPool(self):
  """Compares factorized and regular AVG pooling on a 1024x2048 variable."""
  matrix = variable_scope.get_variable("weights", shape=[1024, 2048])
  # 2x4 windows with a matching 2x4 stride.
  self._compare_pooling_methods(
      matrix, {
          "window_shape": [2, 4],
          "pooling_type": "AVG",
          "strides": [2, 4],
          "padding": "SAME"
      })
def testFactorizedMaxPool(self):
  """Compares factorized and regular MAX pooling on a 1024x2048 variable."""
  matrix = variable_scope.get_variable("weights", shape=[1024, 2048])
  # 2x4 windows with a matching 2x4 stride.
  self._compare_pooling_methods(
      matrix, {
          "window_shape": [2, 4],
          "pooling_type": "MAX",
          "strides": [2, 4],
          "padding": "SAME"
      })
if __name__ == "__main__":