Add stack function to easily stack multiple layers of the same kind together.
Change: 123443759
This commit is contained in: parent d1a8cfc72c, commit 23c79213ec
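A minimal usage sketch of the new helper (not part of the diff; it assumes the contrib-era API that this change extends, `tf.contrib.layers.stack` and `tf.contrib.layers.fully_connected`):

```python
import tensorflow as tf

# Build a three-layer MLP: each entry of stack_args is the num_outputs argument
# for one fully_connected call, scoped 'fc/fc_1', 'fc/fc_2', 'fc/fc_3'.
images = tf.random_uniform((8, 28 * 28), seed=1)
net = tf.contrib.layers.stack(images,
                              tf.contrib.layers.fully_connected,
                              [256, 128, 10],
                              scope='fc')
```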
@@ -50,6 +50,7 @@ __all__ = ['avg_pool2d',
           'fully_connected',
           'max_pool2d',
           'one_hot_encoding',
           'stack',
           'legacy_convolution2d',
           'legacy_fully_connected',
           'legacy_linear',
@@ -600,6 +601,49 @@ def _apply_activation(y, activation_fn, output_collections):
  return y


def stack(inputs, layer, stack_args, **kwargs):
  """Builds a stack of layers by applying layer repeatedly using stack_args.

  `stack` allows you to repeatedly apply the same operation with different
  arguments `stack_args[i]`. For each application of the layer, `stack` creates
  a new scope appended with an increasing number. For example:

  ```python
  x = stack(x, fully_connected, [32, 64, 128], scope='fc')
  # It is equivalent to:

  x = fully_connected(x, 32, scope='fc/fc_1')
  x = fully_connected(x, 64, scope='fc/fc_2')
  x = fully_connected(x, 128, scope='fc/fc_3')
  ```

  Args:
    inputs: A `Tensor` suitable for layer.
    layer: A layer with signature `layer(inputs, *args, **kwargs)`.
    stack_args: A list/tuple of parameters for each call of layer.
    **kwargs: Extra kwargs for the layer.

  Returns:
    A `Tensor` result of applying the stacked layers.

  Raises:
    ValueError: if `stack_args` is not a list or tuple.
  """
  scope = kwargs.pop('scope', None)
  if not isinstance(stack_args, (list, tuple)):
    raise ValueError('stack_args needs to be a list or tuple')
  with variable_scope.variable_op_scope([inputs], scope, 'Stack'):
    outputs = inputs
    scope = scope or layer.__name__
    for i in range(len(stack_args)):
      kwargs['scope'] = scope + '_' + str(i + 1)
      layer_args = stack_args[i]
      if not isinstance(layer_args, (list, tuple)):
        layer_args = [layer_args]
      outputs = layer(outputs, *layer_args, **kwargs)
    return outputs


def legacy_fully_connected(x,
                           num_output_units,
                           activation_fn=None,
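Because each `stack_args[i]` that is not already a list or tuple gets wrapped in one before the call, a single entry can also carry several positional arguments for the layer. A hedged sketch of that use (not part of the diff, assuming contrib's `convolution2d(inputs, num_outputs, kernel_size, ...)` positional order):

```python
import tensorflow as tf

# Each tuple supplies (num_outputs, kernel_size) positionally to convolution2d,
# so the stacked layers can differ in more than one argument.
images = tf.random_uniform((5, 32, 32, 3), seed=1)
net = tf.contrib.layers.stack(images,
                              tf.contrib.layers.convolution2d,
                              [(16, [3, 3]), (32, [5, 5]), (64, [1, 1])],
                              padding='SAME')
```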
@@ -878,6 +878,44 @@ class OneHotEncodingTest(tf.test.TestCase):
      self.assertAllClose(output.eval(), one_hot_labels.eval())


class StackTests(tf.test.TestCase):

  def testStackFullyConnected(self):
    height, width = 3, 3
    with self.test_session():
      images = tf.random_uniform((5, height * width * 3), seed=1, name='images')
      output = tf.contrib.layers.stack(images,
                                       tf.contrib.layers.fully_connected,
                                       [10, 20, 30])
      self.assertEquals(output.op.name, 'Stack/fully_connected_3/Relu')
      self.assertListEqual(output.get_shape().as_list(), [5, 30])

  def testStackConvolution2d(self):
    height, width = 3, 3
    with self.test_session():
      images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
      output = tf.contrib.layers.stack(images,
                                       tf.contrib.layers.convolution2d,
                                       [10, 20, 30],
                                       kernel_size=[3, 3],
                                       padding='SAME')
      self.assertEquals(output.op.name, 'Stack/convolution2d_3/Relu')
      self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30])

  def testStackWithScope(self):
    height, width = 3, 3
    with self.test_session():
      images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
      output = tf.contrib.layers.stack(images,
                                       tf.contrib.layers.convolution2d,
                                       [10, 20, 30],
                                       kernel_size=[3, 3],
                                       padding='SAME',
                                       scope='conv1')
      self.assertEquals(output.op.name, 'conv1/conv1_3/Relu')
      self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30])


# TODO(b/28426988): Add separate tests for non-legacy versions.
class LegacyFullyConnectedTest(tf.test.TestCase):
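The expected op names in these tests follow directly from the scoping in `stack`; a hypothetical sketch of that naming (illustration only, not code from the diff):

```python
# The outer variable_op_scope uses the caller's scope ('conv1') or defaults to
# 'Stack'; each call then receives '<scope or layer.__name__>_<i + 1>' as its scope.
outer = 'conv1'
inner = [outer + '_' + str(i + 1) for i in range(3)]
# inner == ['conv1_1', 'conv1_2', 'conv1_3'], so the last op of the third layer
# is named 'conv1/conv1_3/Relu', which is what testStackWithScope asserts.
```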