Extend block sparsity support for TPUs
PiperOrigin-RevId: 195685740
This commit is contained in:
parent
b2888c66e6
commit
9ba26ca0d5
@ -396,14 +396,19 @@ class Pruning(object):
|
||||
self._block_pooling_function)
|
||||
|
||||
with ops.name_scope(weights.op.name + '_pruning_ops'):
|
||||
abs_weights = math_ops.abs(
|
||||
array_ops.reshape(weights, [
|
||||
1,
|
||||
squeezed_weights.get_shape()[0],
|
||||
squeezed_weights.get_shape()[1], 1
|
||||
]))
|
||||
abs_weights = math_ops.abs(squeezed_weights)
|
||||
|
||||
pool_window = [self._block_dim[0], self._block_dim[1]]
|
||||
pooled_weights = nn_ops.pool(
|
||||
pool_fn = pruning_utils.factorized_pool
|
||||
|
||||
if not self._spec.use_tpu:
|
||||
pool_fn = nn_ops.pool
|
||||
abs_weights = array_ops.reshape(
|
||||
abs_weights,
|
||||
[1, abs_weights.get_shape()[0],
|
||||
abs_weights.get_shape()[1], 1])
|
||||
|
||||
pooled_weights = pool_fn(
|
||||
abs_weights,
|
||||
window_shape=pool_window,
|
||||
pooling_type=self._block_pooling_function,
|
||||
@ -411,19 +416,18 @@ class Pruning(object):
|
||||
padding='SAME',
|
||||
name=weights.op.name + '_pooled')
|
||||
|
||||
if pooled_weights.get_shape().ndims != 2:
|
||||
pooled_weights = array_ops.squeeze(pooled_weights)
|
||||
|
||||
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
|
||||
threshold)
|
||||
|
||||
reshaped_mask = array_ops.reshape(
|
||||
new_mask,
|
||||
[pooled_weights.get_shape()[1],
|
||||
pooled_weights.get_shape()[2]])
|
||||
updated_mask = pruning_utils.kronecker_product(
|
||||
reshaped_mask, array_ops.ones(self._block_dim))
|
||||
new_mask, array_ops.ones(self._block_dim))
|
||||
sliced_mask = array_ops.slice(
|
||||
updated_mask, [0, 0],
|
||||
[squeezed_weights.get_shape()[0],
|
||||
squeezed_weights.get_shape()[1]])
|
||||
|
||||
return smoothed_threshold, array_ops.reshape(sliced_mask,
|
||||
array_ops.shape(weights))
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs):
|
||||
return math_ops.div(cdf, math_ops.reduce_max(cdf))
|
||||
|
||||
|
||||
def factorized_pool(input_tensor,
|
||||
window_shape,
|
||||
pooling_type,
|
||||
strides,
|
||||
padding,
|
||||
name=None):
|
||||
"""Performs m x n pooling through a combination of 1xm and 1xn pooling.
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor. Must be rank 2
|
||||
window_shape: Pooling window shape
|
||||
pooling_type: Either 'MAX' or 'AVG'
|
||||
strides: The stride of the pooling window
|
||||
padding: 'SAME' or 'VALID'.
|
||||
name: Name of the op
|
||||
|
||||
Returns:
|
||||
A rank 2 tensor containing the pooled output
|
||||
|
||||
Raises:
|
||||
ValueError: if the input tensor is not rank 2
|
||||
"""
|
||||
if input_tensor.get_shape().ndims != 2:
|
||||
raise ValueError('factorized_pool() accepts tensors of rank 2 only')
|
||||
|
||||
[height, width] = input_tensor.get_shape()
|
||||
with ops.name_scope(name, 'factorized_pool'):
|
||||
input_tensor_aligned = array_ops.reshape(
|
||||
input_tensor, [1, 1, height, width],
|
||||
name=input_tensor.op.name + '_aligned')
|
||||
|
||||
height_pooling = nn_ops.pool(
|
||||
input_tensor_aligned,
|
||||
window_shape=[1, window_shape[0]],
|
||||
pooling_type=pooling_type,
|
||||
strides=[1, strides[0]],
|
||||
padding=padding)
|
||||
swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])
|
||||
|
||||
width_pooling = nn_ops.pool(
|
||||
swap_height_width,
|
||||
window_shape=[1, window_shape[1]],
|
||||
pooling_type=pooling_type,
|
||||
strides=[1, strides[1]],
|
||||
padding=padding)
|
||||
|
||||
return array_ops.squeeze(
|
||||
array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]))
|
||||
|
||||
|
||||
def determine_partitioned_axis(partitioned_variable):
|
||||
partitioned_axis = 0
|
||||
concatenated_variable_shape = partitioned_variable.get_shape()
|
||||
|
@ -22,8 +22,10 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.model_pruning.python import pruning_utils
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -31,6 +33,30 @@ from tensorflow.python.platform import test
|
||||
|
||||
class PruningUtilsTest(test.TestCase):
|
||||
|
||||
def _compare_cdf(self, values):
|
||||
abs_values = math_ops.abs(values)
|
||||
max_value = math_ops.reduce_max(abs_values)
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
|
||||
abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
|
||||
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
|
||||
self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
|
||||
|
||||
def _compare_pooling_methods(self, weights, pooling_kwargs):
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
pooled_weights_tf = array_ops.squeeze(
|
||||
nn_ops.pool(
|
||||
array_ops.reshape(
|
||||
weights,
|
||||
[1, weights.get_shape()[0],
|
||||
weights.get_shape()[1], 1]), **pooling_kwargs))
|
||||
pooled_weights_factorized_pool = pruning_utils.factorized_pool(
|
||||
weights, **pooling_kwargs)
|
||||
self.assertAllClose(pooled_weights_tf.eval(),
|
||||
pooled_weights_factorized_pool.eval())
|
||||
|
||||
def testHistogram(self):
|
||||
width = 10
|
||||
height = 10
|
||||
@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase):
|
||||
self.assertAllEqual(len(norm_cdf_val), nbins)
|
||||
self.assertAllEqual(expected_cdf, norm_cdf_val)
|
||||
|
||||
def _compare_cdf(self, values):
|
||||
abs_values = math_ops.abs(values)
|
||||
max_value = math_ops.reduce_max(abs_values)
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
|
||||
abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
|
||||
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
|
||||
return cdf.eval(), cdf_from_histogram.eval()
|
||||
|
||||
def testCDFEquivalence2D(self):
|
||||
width = 100
|
||||
height = 100
|
||||
weights = variable_scope.get_variable("weights", shape=[width, height])
|
||||
cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
|
||||
self.assertAllEqual(cdf_val, cdf_from_histogram_val)
|
||||
self._compare_cdf(weights)
|
||||
|
||||
def testCDFEquivalence4D(self):
|
||||
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
|
||||
cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
|
||||
self.assertAllEqual(cdf_val, cdf_from_histogram_val)
|
||||
self._compare_cdf(weights)
|
||||
|
||||
def testFactorizedAvgPool(self):
|
||||
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
|
||||
pooling_kwargs = {
|
||||
"window_shape": [2, 4],
|
||||
"pooling_type": "AVG",
|
||||
"strides": [2, 4],
|
||||
"padding": "SAME"
|
||||
}
|
||||
self._compare_pooling_methods(weights, pooling_kwargs)
|
||||
|
||||
def testFactorizedMaxPool(self):
|
||||
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
|
||||
pooling_kwargs = {
|
||||
"window_shape": [2, 4],
|
||||
"pooling_type": "MAX",
|
||||
"strides": [2, 4],
|
||||
"padding": "SAME"
|
||||
}
|
||||
self._compare_pooling_methods(weights, pooling_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user