Add stack function to easily stack multiple layers of the same kind together.

Change: 123443759
This commit is contained in:
A. Unique TensorFlower 2016-05-27 11:26:57 -08:00 committed by TensorFlower Gardener
parent d1a8cfc72c
commit 23c79213ec
2 changed files with 82 additions and 0 deletions

View File

@ -50,6 +50,7 @@ __all__ = ['avg_pool2d',
'fully_connected',
'max_pool2d',
'one_hot_encoding',
'stack',
'legacy_convolution2d',
'legacy_fully_connected',
'legacy_linear',
@ -600,6 +601,49 @@ def _apply_activation(y, activation_fn, output_collections):
return y
def stack(inputs, layer, stack_args, **kwargs):
"""Builds a stack of layers by applying layer repeatedly using stack_args.
`stack` allows you to repeatedly apply the same operation with different
arguments `stack_args[i]`. For each application of the layer, `stack` creates
a new scope appended with an increasing number. For example:
```python
stack(x, fully_connected, [32, 64, 128], scope='fc')
# It is equivalent to:
x = fully_connected(x, 32, scope='fc/fc_1')
x = fully_connected(x, 64, scope='fc/fc_2')
x = fully_connected(x, 128, scope='fc/fc_3')
```
Args:
inputs: A `Tensor` suitable for layer.
layer: A layer(inputs, *args, **kwargs)
stack_args: A list/tuple of parameters for each call of layer.
**kwargs: Extra kwargs for the layer.
Returns:
a `Tensor` result of applying the stacked layers.
Raises:
ValueError: if the op is unknown or wrong.
"""
scope = kwargs.pop('scope', None)
if not isinstance(stack_args, (list, tuple)):
raise ValueError('stack_args need to be a list or tuple')
with variable_scope.variable_op_scope([inputs], scope, 'Stack'):
outputs = inputs
scope = scope or layer.__name__
for i in range(len(stack_args)):
kwargs['scope'] = scope + '_' + str(i+1)
layer_args = stack_args[i]
if not isinstance(layer_args, (list, tuple)):
layer_args = [layer_args]
outputs = layer(outputs, *layer_args, **kwargs)
return outputs
def legacy_fully_connected(x,
num_output_units,
activation_fn=None,

View File

@ -878,6 +878,44 @@ class OneHotEncodingTest(tf.test.TestCase):
self.assertAllClose(output.eval(), one_hot_labels.eval())
class StackTests(tf.test.TestCase):
def testStackFullyConnected(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height * width * 3), seed=1, name='images')
output = tf.contrib.layers.stack(images,
tf.contrib.layers.fully_connected,
[10, 20, 30])
self.assertEquals(output.op.name, 'Stack/fully_connected_3/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 30])
def testStackConvolution2d(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
output = tf.contrib.layers.stack(images,
tf.contrib.layers.convolution2d,
[10, 20, 30],
kernel_size=[3, 3],
padding='SAME')
self.assertEquals(output.op.name, 'Stack/convolution2d_3/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30])
def testStackWithScope(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
output = tf.contrib.layers.stack(images,
tf.contrib.layers.convolution2d,
[10, 20, 30],
kernel_size=[3, 3],
padding='SAME',
scope='conv1')
self.assertEquals(output.op.name, 'conv1/conv1_3/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30])
# TODO(b/28426988): Add separate tests for non-legacy versions.
class LegacyFullyConnectedTest(tf.test.TestCase):