Add trainable arg to fully_connected.

Add linear, relu, relu6.
Remove legacy_relu6 and legacy_convolution2d.
Change: 123898222
Authored by A. Unique TensorFlower on 2016-06-02 12:02:49 -08:00; committed by TensorFlower Gardener
parent 6cc37af4fb
commit f6acee434c
2 changed files with 45 additions and 285 deletions
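Taken together, the diff below adds a `trainable` argument to batch_norm, bias_add, convolution2d and fully_connected, and replaces the removed legacy aliases with relu, relu6 and linear partials of fully_connected. A minimal sketch of what that enables for callers, assuming the tf.contrib.layers API as of this commit (names and shapes here are illustrative only):

import tensorflow as tf

inputs = tf.random_uniform((5, 27), seed=1)

# New aliases: fully_connected with the activation baked in via functools.partial.
hidden = tf.contrib.layers.relu(inputs, 32)     # fully_connected + tf.nn.relu
capped = tf.contrib.layers.relu6(hidden, 16)    # fully_connected + tf.nn.relu6
logits = tf.contrib.layers.linear(capped, 10)   # fully_connected, no activation

# New trainable flag: variables are still created, but with trainable=False they
# stay out of tf.GraphKeys.TRAINABLE_VARIABLES, so optimizers leave them alone.
frozen = tf.contrib.layers.fully_connected(inputs, 32, trainable=False)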


@@ -48,14 +48,15 @@ __all__ = ['avg_pool2d',
'dropout',
'flatten',
'fully_connected',
'linear',
'max_pool2d',
'one_hot_encoding',
'relu',
'relu6',
'stack',
'legacy_convolution2d',
'legacy_fully_connected',
'legacy_linear',
'legacy_relu',
'legacy_relu6']
'legacy_relu']
@add_arg_scope
@@ -107,6 +108,7 @@ def batch_norm(inputs,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -139,6 +141,8 @@ def batch_norm(inputs,
able to reuse the layer scope must be given.
variables_collections: optional collections for the variables.
outputs_collections: collections to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for `variable_op_scope`.
Returns:
@@ -160,7 +164,8 @@ def batch_norm(inputs,
shape=params_shape,
dtype=dtype,
initializer=init_ops.zeros_initializer,
collections=beta_collections)
collections=beta_collections,
trainable=trainable)
if scale:
gamma_collections = utils.get_variable_collections(variables_collections,
'gamma')
@@ -168,7 +173,8 @@ def batch_norm(inputs,
shape=params_shape,
dtype=dtype,
initializer=init_ops.ones_initializer,
collections=gamma_collections)
collections=gamma_collections,
trainable=trainable)
# Create moving_mean and moving_variance variables and add them to the
# appropiate collections.
moving_mean_collections = utils.get_variable_collections(
@@ -226,6 +232,7 @@ def bias_add(inputs,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a bias to the inputs.
@@ -242,6 +249,8 @@ def bias_add(inputs,
able to reuse the layer scope must be given.
variables_collections: optional collections for the variables.
outputs_collections: collections to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for variable_op_scope.
Returns:
@@ -258,7 +267,8 @@ def bias_add(inputs,
dtype=dtype,
initializer=initializer,
regularizer=regularizer,
collections=biases_collections)
collections=biases_collections,
trainable=trainable)
outputs = nn.bias_add(inputs, biases)
if activation_fn:
outputs = activation_fn(outputs)
@@ -281,6 +291,7 @@ def convolution2d(inputs,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a 2D convolution followed by an optional batch_norm layer.
@@ -315,6 +326,8 @@ def convolution2d(inputs,
variables_collections: optional list of collections for all the variables or
a dictionay containing a different list of collection per variable.
outputs_collections: collection to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for `variable_op_scope`.
Returns:
@@ -336,7 +349,8 @@ def convolution2d(inputs,
dtype=dtype,
initializer=weights_initializer,
regularizer=weights_regularizer,
collections=weights_collections)
collections=weights_collections,
trainable=trainable)
outputs = nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
padding=padding)
if normalizer_fn:
@@ -351,7 +365,8 @@ def convolution2d(inputs,
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
collections=biases_collections)
collections=biases_collections,
trainable=trainable)
outputs = nn.bias_add(outputs, biases)
if activation_fn:
outputs = activation_fn(outputs)
@@ -435,6 +450,7 @@ def fully_connected(inputs,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a fully connected layer.
@@ -465,8 +481,10 @@ def fully_connected(inputs,
reuse: whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
variables_collections: Optional list of collections for all the variables or
a dictionay containing a different list of collection per variable.
a dictionary containing a different list of collections per variable.
outputs_collections: collection to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
scope: Optional scope for variable_op_scope.
Returns:
@@ -496,7 +514,8 @@ def fully_connected(inputs,
dtype=dtype,
initializer=weights_initializer,
regularizer=weights_regularizer,
collections=weights_collections)
collections=weights_collections,
trainable=trainable)
if len(static_shape) > 2:
# Reshape inputs
inputs = array_ops.reshape(inputs, [-1, num_input_units])
@@ -513,7 +532,8 @@ def fully_connected(inputs,
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
collections=biases_collections)
collections=biases_collections,
trainable=trainable)
outputs = nn.bias_add(outputs, biases)
if len(static_shape) > 2:
# Reshape back outputs
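The two fully_connected hunks above thread `trainable` through to the weights and biases variables while keeping the existing rank-greater-than-2 path: inputs are flattened to two dimensions for the matmul and the outputs are reshaped back afterwards. A hedged sketch of that behavior from the caller's side (illustrative shapes, assuming the contrib API at this commit):

import tensorflow as tf

# A [5, 7, 9] input is reshaped to [35, 9] for the matmul, then back to [5, 7, 32].
x = tf.random_uniform((5, 7, 9), seed=1)
y = tf.contrib.layers.fully_connected(x, 32, trainable=False)
print(y.get_shape().as_list())  # [5, 7, 32]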
@@ -773,127 +793,13 @@ def legacy_fully_connected(x,
return _apply_activation(y, activation_fn, output_collections)
def legacy_convolution2d(x,
num_output_channels,
kernel_size,
activation_fn=None,
stride=(1, 1),
padding='SAME',
weight_init=initializers.xavier_initializer_conv2d(),
bias_init=standard_ops.zeros_initializer,
name=None,
weight_collections=(ops.GraphKeys.WEIGHTS,),
bias_collections=(ops.GraphKeys.BIASES,),
output_collections=(ops.GraphKeys.ACTIVATIONS,),
trainable=True,
weight_regularizer=None,
bias_regularizer=None):
# pylint: disable=g-docstring-has-escape
"""Adds the parameters for a conv2d layer and returns the output.
A neural network convolution layer is generally defined as:
\\\\(y = f(conv2d(w, x) + b)\\\\) where **f** is given by `activation_fn`,
**conv2d** is `tf.nn.conv2d` and `x` has shape
`[batch, height, width, channels]`. The output of this op is of shape
`[batch, out_height, out_width, num_output_channels]`, where `out_width` and
`out_height` are determined by the `padding` argument. See `conv2D` for
details.
This op creates `w` and optionally `b` and adds various summaries that can be
useful for visualizing learning or diagnosing training problems. Bias can be
disabled by setting `bias_init` to `None`.
The variable creation is compatible with `tf.variable_scope` and so can be
reused with `tf.variable_scope` or `tf.make_template`.
Most of the details of variable creation can be controlled by specifying the
initializers (`weight_init` and `bias_init`) and which collections to place
the created variables in (`weight_collections` and `bias_collections`).
A per layer regularization can be specified by setting `weight_regularizer`.
This is only applied to weights and not the bias.
Args:
x: A 4-D input `Tensor`.
num_output_channels: The number of output channels (i.e. the size of the
last dimension of the output).
kernel_size: A length 2 `list` or `tuple` containing the kernel size.
activation_fn: A function that requires a single Tensor that is applied as a
non-linearity.
stride: A length 2 `list` or `tuple` specifying the stride of the sliding
window across the image.
padding: A `string` from: "SAME", "VALID". The type of padding algorithm to
use.
weight_init: An optional initialization. If not specified, uses Xavier
initialization (see `tf.learn.xavier_initializer`).
bias_init: An initializer for the bias, defaults to 0. Set to`None` in order
to disable bias.
name: The name for this operation is used to name operations and to find
variables. If specified it must be unique for this scope, otherwise a
unique name starting with "convolution2d" will be created. See
`tf.variable_op_scope` for details.
weight_collections: List of graph collections to which weights are added.
bias_collections: List of graph collections to which biases are added.
output_collections: List of graph collections to which outputs are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
weight_regularizer: A regularizer like the result of
`l1_regularizer` or `l2_regularizer`. Used for weights.
bias_regularizer: A regularizer like the result of
`l1_regularizer` or `l2_regularizer`. Used for biases.
Returns:
The result of applying a 2-D convolutional layer.
Raises:
ValueError: If `kernel_size` or `stride` are not length 2.
"""
with variable_scope.variable_op_scope([x], name, 'convolution2d'):
num_input_channels = x.get_shape().dims[3].value
if len(kernel_size) != 2:
raise ValueError('kernel_size must be length 2: %d ' % kernel_size)
if len(stride) != 2:
raise ValueError('stride must be length 2: %d' % stride)
stride = [1, stride[0], stride[1], 1]
shape = [kernel_size[0], kernel_size[1], num_input_channels,
num_output_channels]
dtype = x.dtype.base_dtype
weight_collections = set(list(weight_collections or []) +
[ops.GraphKeys.VARIABLES])
w = variable_scope.get_variable('weights',
shape=shape,
dtype=dtype,
initializer=weight_init,
collections=weight_collections,
regularizer=weight_regularizer,
trainable=trainable)
y = nn.conv2d(x, w, stride, padding)
if bias_init is not None:
bias_collections = set(list(bias_collections or []) +
[ops.GraphKeys.VARIABLES])
b = variable_scope.get_variable('bias',
shape=[num_output_channels],
dtype=dtype,
initializer=bias_init,
collections=bias_collections,
regularizer=bias_regularizer,
trainable=trainable)
y = nn.bias_add(y, b)
return _apply_activation(y, activation_fn, output_collections)
# TODO(eiderm): Verify and fix autocomplete in colab (also relu6).
# Simple aliases which remove the activation_fn parameter.
legacy_relu = functools.partial(legacy_fully_connected, activation_fn=nn.relu)
legacy_relu6 = functools.partial(legacy_fully_connected, activation_fn=nn.relu6)
# Simple alias for fully_connected which removes the activation_fn parameter.
legacy_linear = functools.partial(legacy_fully_connected, activation_fn=None)
relu = functools.partial(fully_connected, activation_fn=nn.relu)
relu6 = functools.partial(fully_connected, activation_fn=nn.relu6)
linear = functools.partial(fully_connected, activation_fn=None)
# Simple alias for convolution2d.
conv2d = convolution2d
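With legacy_relu6 and legacy_convolution2d gone, the alias block above is the migration target: relu, relu6 and linear are plain functools.partial wrappers around fully_connected, and conv2d aliases convolution2d. A hedged migration sketch for old call sites, assuming the non-legacy layers take (inputs, num_outputs, kernel_size) positionally as in the hunks above:

import functools
import tensorflow as tf

images = tf.random_uniform((2, 3, 3, 4), seed=1)

# legacy_convolution2d(x, 8, (3, 3), activation_fn=tf.nn.relu) becomes:
conv = tf.contrib.layers.convolution2d(images, 8, (3, 3), activation_fn=tf.nn.relu)

# legacy_relu6(x, 8) becomes the fully_connected-based alias:
flat = tf.contrib.layers.flatten(conv)
out = tf.contrib.layers.relu6(flat, 8)

# The same partial pattern used above also composes with the new trainable flag:
frozen_linear = functools.partial(tf.contrib.layers.fully_connected,
                                  activation_fn=None, trainable=False)
logits = frozen_linear(flat, 10)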


@@ -435,15 +435,16 @@ class FCTest(tf.test.TestCase):
def testCreateFC(self):
height, width = 3, 3
with self.test_session():
inputs = tf.random_uniform((5, height * width * 3), seed=1)
output = tf.contrib.layers.fully_connected(inputs, 32)
self.assertEquals(output.op.name, 'fully_connected/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 32])
weights = tf.contrib.framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32])
biases = tf.contrib.framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
for layer_fn in (tf.contrib.layers.fully_connected, tf.contrib.layers.relu):
with tf.Graph().as_default() as g, self.test_session(g):
inputs = tf.random_uniform((5, height * width * 3), seed=1)
output = layer_fn(inputs, 32)
self.assertEquals(output.op.name, 'fully_connected/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 32])
weights = tf.contrib.framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32])
biases = tf.contrib.framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
def testCreateFCWithScope(self):
height, width = 3, 3
@@ -985,27 +986,6 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
self.assertEqual(0,
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
def test_relu6_layer_basic_use(self):
output = tf.contrib.layers.legacy_relu6(self.input, 8)
with tf.Session() as sess:
with self.assertRaises(tf.errors.FailedPreconditionError):
sess.run(output)
tf.initialize_all_variables().run()
out_value = sess.run(output)
self.assertEqual(output.get_shape().as_list(), [2, 8])
self.assertTrue(np.all(out_value >= 0),
'Relu6 should have all values >= 0.')
self.assertTrue(np.all(out_value <= 6),
'Relu6 should have all values <= 6.')
self.assertEqual(2,
len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
self.assertEqual(0,
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
def test_variable_reuse_with_scope(self):
with tf.variable_scope('test') as vs:
output1 = tf.contrib.layers.legacy_relu(self.input, 8)
@@ -1208,131 +1188,5 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
activation_fn=tf.nn.softmax)
class LegacyConvolution2dTest(tf.test.TestCase):
def setUp(self):
tf.test.TestCase.setUp(self)
tf.set_random_seed(1234)
self.input = tf.constant(np.arange(2 * 3 * 3 * 4).reshape(
[2, 3, 3, 4]).astype(np.float32))
assert not tf.get_collection(tf.GraphKeys.SUMMARIES)
def test_basic_use(self):
output = tf.contrib.layers.legacy_convolution2d(self.input,
8, (3, 3),
activation_fn=tf.nn.relu)
with tf.Session() as sess:
with self.assertRaises(tf.errors.FailedPreconditionError):
sess.run(output)
tf.initialize_all_variables().run()
out_value = sess.run(output)
self.assertEqual(output.get_shape().as_list(), [2, 3, 3, 8])
self.assertTrue(np.all(out_value >= 0),
'Relu should have capped all values.')
self.assertEqual(2,
len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
self.assertEqual(0,
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
def test_variable_reuse_with_scope(self):
with tf.variable_scope('test') as vs:
output1 = tf.contrib.layers.legacy_convolution2d(self.input,
8, (3, 3),
activation_fn=tf.nn.relu)
output2 = tf.contrib.layers.legacy_convolution2d(self.input,
8, (3, 3),
activation_fn=tf.nn.relu)
with tf.variable_scope(vs, reuse=True):
output3 = tf.contrib.layers.legacy_convolution2d(self.input,
8, (3, 3),
activation_fn=tf.nn.relu)
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])
self.assertFalse(np.allclose(out_value1, out_value2))
self.assertAllClose(out_value1, out_value3)
def test_variable_reuse_with_template(self):
tmpl1 = tf.make_template('test',
tf.contrib.layers.legacy_convolution2d,
kernel_size=(3, 3),
num_output_channels=8)
output1 = tmpl1(self.input)
output2 = tmpl1(self.input)
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value1, out_value2 = sess.run([output1, output2])
self.assertAllClose(out_value1, out_value2)
def test_custom_initializers(self):
output = tf.contrib.layers.legacy_convolution2d(
self.input,
2, (3, 3),
activation_fn=tf.nn.relu,
weight_init=tf.constant_initializer(2.0),
bias_init=tf.constant_initializer(1.0),
padding='VALID')
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value = sess.run(output)
self.assertAllClose(
np.array([[[[1261., 1261.]]], [[[3853., 3853.]]]]), out_value)
def test_custom_collections(self):
tf.contrib.layers.legacy_convolution2d(self.input,
2, (3, 3),
activation_fn=tf.nn.relu,
weight_collections=['unbiased'],
bias_collections=['biased'])
self.assertEquals(1, len(tf.get_collection('unbiased')))
self.assertEquals(1, len(tf.get_collection('biased')))
def test_all_custom_collections(self):
tf.contrib.layers.legacy_convolution2d(
self.input,
2, (3, 3),
activation_fn=tf.nn.relu,
weight_collections=['unbiased', 'all'],
bias_collections=['biased', 'all'])
self.assertEquals(1, len(tf.get_collection('unbiased')))
self.assertEquals(1, len(tf.get_collection('biased')))
self.assertEquals(2, len(tf.get_collection('all')))
self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
tf.get_collection('all'))
def test_regularizer(self):
cnt = [0]
tensor = tf.constant(5.0)
def test_fn(_):
cnt[0] += 1
return tensor
tf.contrib.layers.legacy_convolution2d(self.input,
2, (3, 3),
weight_regularizer=test_fn)
self.assertEqual([tensor],
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
self.assertEqual(1, cnt[0])
def test_no_bias(self):
tf.contrib.layers.legacy_convolution2d(self.input,
2, (3, 3),
bias_init=None)
self.assertEqual(1,
len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
if __name__ == '__main__':
tf.test.main()
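The removed LegacyConvolution2dTest class above covered the deleted layer, and the test hunks shown here don't exercise the new trainable flag directly. A hypothetical test in the same style (test name and assertions are assumptions, not part of this commit) might look like:

  def testCreateFCWithoutTrainable(self):
    height, width = 3, 3
    with self.test_session():
      inputs = tf.random_uniform((5, height * width * 3), seed=1)
      tf.contrib.layers.fully_connected(inputs, 32, trainable=False)
      # Variables are created but excluded from the trainable collection.
      self.assertEquals(
          0, len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))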