Add trainable arg to batch_norm, bias_add, convolution2d and fully_connected.

Add linear, relu, relu6.
Remove legacy_relu6 and legacy_convolution2d.
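
For context, a minimal usage sketch of what this change enables. It is illustrative only (it assumes the tf.contrib.layers Python API at this revision; the input tensor and layer sizes are made up, not part of the commit):

import tensorflow as tf

inputs = tf.random_uniform((5, 27), seed=1)  # illustrative batch of flattened features

# New aliases: fully_connected with a fixed activation_fn.
hidden = tf.contrib.layers.relu(inputs, 32)    # same as fully_connected(inputs, 32, activation_fn=tf.nn.relu)
logits = tf.contrib.layers.linear(hidden, 10)  # same as fully_connected(hidden, 10, activation_fn=None)

# New trainable argument: keep a layer's weights/biases out of
# GraphKeys.TRAINABLE_VARIABLES, e.g. to freeze that layer during training.
frozen = tf.contrib.layers.fully_connected(inputs, 32, trainable=False)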
Change: 123898222
Authored by A. Unique TensorFlower on 2016-06-02 12:02:49 -08:00, committed by TensorFlower Gardener
parent 6cc37af4fb
commit f6acee434c
2 changed files with 45 additions and 285 deletions


@@ -48,14 +48,15 @@ __all__ = ['avg_pool2d',
            'dropout',
            'flatten',
            'fully_connected',
+           'linear',
            'max_pool2d',
            'one_hot_encoding',
+           'relu',
+           'relu6',
            'stack',
-           'legacy_convolution2d',
            'legacy_fully_connected',
            'legacy_linear',
-           'legacy_relu',
-           'legacy_relu6']
+           'legacy_relu']


 @add_arg_scope
@@ -107,6 +108,7 @@ def batch_norm(inputs,
                reuse=None,
                variables_collections=None,
                outputs_collections=None,
+               trainable=True,
                scope=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -139,6 +141,8 @@ def batch_norm(inputs,
       able to reuse the layer scope must be given.
     variables_collections: optional collections for the variables.
     outputs_collections: collections to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     scope: Optional scope for `variable_op_scope`.

   Returns:
@@ -160,7 +164,8 @@ def batch_norm(inputs,
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.zeros_initializer,
-                                      collections=beta_collections)
+                                      collections=beta_collections,
+                                      trainable=trainable)
     if scale:
       gamma_collections = utils.get_variable_collections(variables_collections,
                                                          'gamma')
@@ -168,7 +173,8 @@ def batch_norm(inputs,
                                        shape=params_shape,
                                        dtype=dtype,
                                        initializer=init_ops.ones_initializer,
-                                       collections=gamma_collections)
+                                       collections=gamma_collections,
+                                       trainable=trainable)
     # Create moving_mean and moving_variance variables and add them to the
     # appropiate collections.
     moving_mean_collections = utils.get_variable_collections(
@@ -226,6 +232,7 @@ def bias_add(inputs,
              reuse=None,
              variables_collections=None,
              outputs_collections=None,
+             trainable=True,
              scope=None):
   """Adds a bias to the inputs.
@@ -242,6 +249,8 @@ def bias_add(inputs,
       able to reuse the layer scope must be given.
     variables_collections: optional collections for the variables.
     outputs_collections: collections to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     scope: Optional scope for variable_op_scope.

   Returns:
@@ -258,7 +267,8 @@ def bias_add(inputs,
                                        dtype=dtype,
                                        initializer=initializer,
                                        regularizer=regularizer,
-                                       collections=biases_collections)
+                                       collections=biases_collections,
+                                       trainable=trainable)
     outputs = nn.bias_add(inputs, biases)
     if activation_fn:
       outputs = activation_fn(outputs)
@@ -281,6 +291,7 @@ def convolution2d(inputs,
                   reuse=None,
                   variables_collections=None,
                   outputs_collections=None,
+                  trainable=True,
                   scope=None):
   """Adds a 2D convolution followed by an optional batch_norm layer.
@@ -315,6 +326,8 @@ def convolution2d(inputs,
     variables_collections: optional list of collections for all the variables or
       a dictionay containing a different list of collection per variable.
     outputs_collections: collection to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     scope: Optional scope for `variable_op_scope`.

   Returns:
@@ -336,7 +349,8 @@ def convolution2d(inputs,
                                          dtype=dtype,
                                          initializer=weights_initializer,
                                          regularizer=weights_regularizer,
-                                         collections=weights_collections)
+                                         collections=weights_collections,
+                                         trainable=trainable)
     outputs = nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
                         padding=padding)
     if normalizer_fn:
@@ -351,7 +365,8 @@ def convolution2d(inputs,
                                           dtype=dtype,
                                           initializer=biases_initializer,
                                           regularizer=biases_regularizer,
-                                          collections=biases_collections)
+                                          collections=biases_collections,
+                                          trainable=trainable)
         outputs = nn.bias_add(outputs, biases)
     if activation_fn:
       outputs = activation_fn(outputs)
@@ -435,6 +450,7 @@ def fully_connected(inputs,
                     reuse=None,
                     variables_collections=None,
                     outputs_collections=None,
+                    trainable=True,
                     scope=None):
   """Adds a fully connected layer.
@@ -465,8 +481,10 @@ def fully_connected(inputs,
     reuse: whether or not the layer and its variables should be reused. To be
       able to reuse the layer scope must be given.
     variables_collections: Optional list of collections for all the variables or
-      a dictionay containing a different list of collection per variable.
+      a dictionary containing a different list of collections per variable.
     outputs_collections: collection to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     scope: Optional scope for variable_op_scope.

   Returns:
@@ -496,7 +514,8 @@ def fully_connected(inputs,
                                        dtype=dtype,
                                        initializer=weights_initializer,
                                        regularizer=weights_regularizer,
-                                       collections=weights_collections)
+                                       collections=weights_collections,
+                                       trainable=trainable)
     if len(static_shape) > 2:
       # Reshape inputs
       inputs = array_ops.reshape(inputs, [-1, num_input_units])
@@ -513,7 +532,8 @@ def fully_connected(inputs,
                                         dtype=dtype,
                                         initializer=biases_initializer,
                                         regularizer=biases_regularizer,
-                                        collections=biases_collections)
+                                        collections=biases_collections,
+                                        trainable=trainable)
       outputs = nn.bias_add(outputs, biases)
     if len(static_shape) > 2:
       # Reshape back outputs
@@ -773,127 +793,13 @@ def legacy_fully_connected(x,
     return _apply_activation(y, activation_fn, output_collections)


-def legacy_convolution2d(x,
-                         num_output_channels,
-                         kernel_size,
-                         activation_fn=None,
-                         stride=(1, 1),
-                         padding='SAME',
-                         weight_init=initializers.xavier_initializer_conv2d(),
-                         bias_init=standard_ops.zeros_initializer,
-                         name=None,
-                         weight_collections=(ops.GraphKeys.WEIGHTS,),
-                         bias_collections=(ops.GraphKeys.BIASES,),
-                         output_collections=(ops.GraphKeys.ACTIVATIONS,),
-                         trainable=True,
-                         weight_regularizer=None,
-                         bias_regularizer=None):
-  # pylint: disable=g-docstring-has-escape
-  """Adds the parameters for a conv2d layer and returns the output.
-
-  A neural network convolution layer is generally defined as:
-  \\\\(y = f(conv2d(w, x) + b)\\\\) where **f** is given by `activation_fn`,
-  **conv2d** is `tf.nn.conv2d` and `x` has shape
-  `[batch, height, width, channels]`. The output of this op is of shape
-  `[batch, out_height, out_width, num_output_channels]`, where `out_width` and
-  `out_height` are determined by the `padding` argument. See `conv2D` for
-  details.
-
-  This op creates `w` and optionally `b` and adds various summaries that can be
-  useful for visualizing learning or diagnosing training problems. Bias can be
-  disabled by setting `bias_init` to `None`.
-
-  The variable creation is compatible with `tf.variable_scope` and so can be
-  reused with `tf.variable_scope` or `tf.make_template`.
-
-  Most of the details of variable creation can be controlled by specifying the
-  initializers (`weight_init` and `bias_init`) and which collections to place
-  the created variables in (`weight_collections` and `bias_collections`).
-
-  A per layer regularization can be specified by setting `weight_regularizer`.
-  This is only applied to weights and not the bias.
-
-  Args:
-    x: A 4-D input `Tensor`.
-    num_output_channels: The number of output channels (i.e. the size of the
-      last dimension of the output).
-    kernel_size: A length 2 `list` or `tuple` containing the kernel size.
-    activation_fn: A function that requires a single Tensor that is applied as a
-      non-linearity.
-    stride: A length 2 `list` or `tuple` specifying the stride of the sliding
-      window across the image.
-    padding: A `string` from: "SAME", "VALID". The type of padding algorithm to
-      use.
-    weight_init: An optional initialization. If not specified, uses Xavier
-      initialization (see `tf.learn.xavier_initializer`).
-    bias_init: An initializer for the bias, defaults to 0. Set to`None` in order
-      to disable bias.
-    name: The name for this operation is used to name operations and to find
-      variables. If specified it must be unique for this scope, otherwise a
-      unique name starting with "convolution2d" will be created. See
-      `tf.variable_op_scope` for details.
-    weight_collections: List of graph collections to which weights are added.
-    bias_collections: List of graph collections to which biases are added.
-    output_collections: List of graph collections to which outputs are added.
-    trainable: If `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
-    weight_regularizer: A regularizer like the result of
-      `l1_regularizer` or `l2_regularizer`. Used for weights.
-    bias_regularizer: A regularizer like the result of
-      `l1_regularizer` or `l2_regularizer`. Used for biases.
-
-  Returns:
-    The result of applying a 2-D convolutional layer.
-
-  Raises:
-    ValueError: If `kernel_size` or `stride` are not length 2.
-  """
-  with variable_scope.variable_op_scope([x], name, 'convolution2d'):
-    num_input_channels = x.get_shape().dims[3].value
-
-    if len(kernel_size) != 2:
-      raise ValueError('kernel_size must be length 2: %d ' % kernel_size)
-    if len(stride) != 2:
-      raise ValueError('stride must be length 2: %d' % stride)
-    stride = [1, stride[0], stride[1], 1]
-    shape = [kernel_size[0], kernel_size[1], num_input_channels,
-             num_output_channels]
-    dtype = x.dtype.base_dtype
-
-    weight_collections = set(list(weight_collections or []) +
-                             [ops.GraphKeys.VARIABLES])
-    w = variable_scope.get_variable('weights',
-                                    shape=shape,
-                                    dtype=dtype,
-                                    initializer=weight_init,
-                                    collections=weight_collections,
-                                    regularizer=weight_regularizer,
-                                    trainable=trainable)
-    y = nn.conv2d(x, w, stride, padding)
-
-    if bias_init is not None:
-      bias_collections = set(list(bias_collections or []) +
-                             [ops.GraphKeys.VARIABLES])
-      b = variable_scope.get_variable('bias',
-                                      shape=[num_output_channels],
-                                      dtype=dtype,
-                                      initializer=bias_init,
-                                      collections=bias_collections,
-                                      regularizer=bias_regularizer,
-                                      trainable=trainable)
-      y = nn.bias_add(y, b)
-
-    return _apply_activation(y, activation_fn, output_collections)
-
-
 # TODO(eiderm): Verify and fix autocomplete in colab (also relu6).
+# Simple aliases which remove the activation_fn parameter.
 legacy_relu = functools.partial(legacy_fully_connected, activation_fn=nn.relu)
-legacy_relu6 = functools.partial(legacy_fully_connected, activation_fn=nn.relu6)
-
-# Simple alias for fully_connected which removes the activation_fn parameter.
 legacy_linear = functools.partial(legacy_fully_connected, activation_fn=None)
+relu = functools.partial(fully_connected, activation_fn=nn.relu)
+relu6 = functools.partial(fully_connected, activation_fn=nn.relu6)
+linear = functools.partial(fully_connected, activation_fn=None)

 # Simple alias for convolution2d.
 conv2d = convolution2d
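
To make the intent of the new argument concrete, here is a small hedged sketch (same API assumptions as the example above; the scope names fc_train and fc_frozen are invented for illustration) showing that trainable=False keeps a layer's variables out of the trainable collection:

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.random_uniform((2, 8))
  tf.contrib.layers.fully_connected(x, 4, scope='fc_train')                    # trainable=True by default
  tf.contrib.layers.fully_connected(x, 4, trainable=False, scope='fc_frozen')  # excluded from training
  train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  # Only fc_train's weights and biases should be collected: two variables.
  assert len(train_vars) == 2
  assert all(v.name.startswith('fc_train') for v in train_vars)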


@@ -435,15 +435,16 @@ class FCTest(tf.test.TestCase):

   def testCreateFC(self):
     height, width = 3, 3
-    with self.test_session():
-      inputs = tf.random_uniform((5, height * width * 3), seed=1)
-      output = tf.contrib.layers.fully_connected(inputs, 32)
-      self.assertEquals(output.op.name, 'fully_connected/Relu')
-      self.assertListEqual(output.get_shape().as_list(), [5, 32])
-      weights = tf.contrib.framework.get_variables_by_name('weights')[0]
-      self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32])
-      biases = tf.contrib.framework.get_variables_by_name('biases')[0]
-      self.assertListEqual(biases.get_shape().as_list(), [32])
+    for layer_fn in (tf.contrib.layers.fully_connected, tf.contrib.layers.relu):
+      with tf.Graph().as_default() as g, self.test_session(g):
+        inputs = tf.random_uniform((5, height * width * 3), seed=1)
+        output = layer_fn(inputs, 32)
+        self.assertEquals(output.op.name, 'fully_connected/Relu')
+        self.assertListEqual(output.get_shape().as_list(), [5, 32])
+        weights = tf.contrib.framework.get_variables_by_name('weights')[0]
+        self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32])
+        biases = tf.contrib.framework.get_variables_by_name('biases')[0]
+        self.assertListEqual(biases.get_shape().as_list(), [32])

   def testCreateFCWithScope(self):
     height, width = 3, 3
@@ -985,27 +986,6 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
     self.assertEqual(0,
                      len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))

-  def test_relu6_layer_basic_use(self):
-    output = tf.contrib.layers.legacy_relu6(self.input, 8)
-
-    with tf.Session() as sess:
-      with self.assertRaises(tf.errors.FailedPreconditionError):
-        sess.run(output)
-
-      tf.initialize_all_variables().run()
-      out_value = sess.run(output)
-
-    self.assertEqual(output.get_shape().as_list(), [2, 8])
-    self.assertTrue(np.all(out_value >= 0),
-                    'Relu6 should have all values >= 0.')
-    self.assertTrue(np.all(out_value <= 6),
-                    'Relu6 should have all values <= 6.')
-    self.assertEqual(2,
-                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
-    self.assertEqual(0,
-                     len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
-
   def test_variable_reuse_with_scope(self):
     with tf.variable_scope('test') as vs:
       output1 = tf.contrib.layers.legacy_relu(self.input, 8)
@@ -1208,131 +1188,5 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
                                           activation_fn=tf.nn.softmax)


-class LegacyConvolution2dTest(tf.test.TestCase):
-
-  def setUp(self):
-    tf.test.TestCase.setUp(self)
-    tf.set_random_seed(1234)
-    self.input = tf.constant(np.arange(2 * 3 * 3 * 4).reshape(
-        [2, 3, 3, 4]).astype(np.float32))
-    assert not tf.get_collection(tf.GraphKeys.SUMMARIES)
-
-  def test_basic_use(self):
-    output = tf.contrib.layers.legacy_convolution2d(self.input,
-                                                    8, (3, 3),
-                                                    activation_fn=tf.nn.relu)
-
-    with tf.Session() as sess:
-      with self.assertRaises(tf.errors.FailedPreconditionError):
-        sess.run(output)
-
-      tf.initialize_all_variables().run()
-      out_value = sess.run(output)
-
-    self.assertEqual(output.get_shape().as_list(), [2, 3, 3, 8])
-    self.assertTrue(np.all(out_value >= 0),
-                    'Relu should have capped all values.')
-    self.assertEqual(2,
-                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
-    self.assertEqual(0,
-                     len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
-
-  def test_variable_reuse_with_scope(self):
-    with tf.variable_scope('test') as vs:
-      output1 = tf.contrib.layers.legacy_convolution2d(self.input,
-                                                       8, (3, 3),
-                                                       activation_fn=tf.nn.relu)
-      output2 = tf.contrib.layers.legacy_convolution2d(self.input,
-                                                       8, (3, 3),
-                                                       activation_fn=tf.nn.relu)
-
-    with tf.variable_scope(vs, reuse=True):
-      output3 = tf.contrib.layers.legacy_convolution2d(self.input,
-                                                       8, (3, 3),
-                                                       activation_fn=tf.nn.relu)
-
-    with tf.Session() as sess:
-      tf.initialize_all_variables().run()
-      out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])
-
-    self.assertFalse(np.allclose(out_value1, out_value2))
-    self.assertAllClose(out_value1, out_value3)
-
-  def test_variable_reuse_with_template(self):
-    tmpl1 = tf.make_template('test',
-                             tf.contrib.layers.legacy_convolution2d,
-                             kernel_size=(3, 3),
-                             num_output_channels=8)
-    output1 = tmpl1(self.input)
-    output2 = tmpl1(self.input)
-
-    with tf.Session() as sess:
-      tf.initialize_all_variables().run()
-      out_value1, out_value2 = sess.run([output1, output2])
-    self.assertAllClose(out_value1, out_value2)
-
-  def test_custom_initializers(self):
-    output = tf.contrib.layers.legacy_convolution2d(
-        self.input,
-        2, (3, 3),
-        activation_fn=tf.nn.relu,
-        weight_init=tf.constant_initializer(2.0),
-        bias_init=tf.constant_initializer(1.0),
-        padding='VALID')
-
-    with tf.Session() as sess:
-      tf.initialize_all_variables().run()
-      out_value = sess.run(output)
-
-    self.assertAllClose(
-        np.array([[[[1261., 1261.]]], [[[3853., 3853.]]]]), out_value)
-
-  def test_custom_collections(self):
-    tf.contrib.layers.legacy_convolution2d(self.input,
-                                           2, (3, 3),
-                                           activation_fn=tf.nn.relu,
-                                           weight_collections=['unbiased'],
-                                           bias_collections=['biased'])
-
-    self.assertEquals(1, len(tf.get_collection('unbiased')))
-    self.assertEquals(1, len(tf.get_collection('biased')))
-
-  def test_all_custom_collections(self):
-    tf.contrib.layers.legacy_convolution2d(
-        self.input,
-        2, (3, 3),
-        activation_fn=tf.nn.relu,
-        weight_collections=['unbiased', 'all'],
-        bias_collections=['biased', 'all'])
-
-    self.assertEquals(1, len(tf.get_collection('unbiased')))
-    self.assertEquals(1, len(tf.get_collection('biased')))
-    self.assertEquals(2, len(tf.get_collection('all')))
-    self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
-                      tf.get_collection('all'))
-
-  def test_regularizer(self):
-    cnt = [0]
-    tensor = tf.constant(5.0)
-    def test_fn(_):
-      cnt[0] += 1
-      return tensor
-
-    tf.contrib.layers.legacy_convolution2d(self.input,
-                                           2, (3, 3),
-                                           weight_regularizer=test_fn)
-
-    self.assertEqual([tensor],
-                     tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
-    self.assertEqual(1, cnt[0])
-
-  def test_no_bias(self):
-    tf.contrib.layers.legacy_convolution2d(self.input,
-                                           2, (3, 3),
-                                           bias_init=None)
-    self.assertEqual(1,
-                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
-
-
 if __name__ == '__main__':
   tf.test.main()
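
Since legacy_convolution2d is removed, callers need to move to the non-legacy convolution2d. A hedged migration sketch (assuming the convolution2d signature shown in the hunks above; its defaults, such as activation_fn=tf.nn.relu, may make some arguments redundant):

import tensorflow as tf

x = tf.random_uniform((2, 3, 3, 4))  # NHWC input, made up for the example

# Before (removed by this change):
#   y = tf.contrib.layers.legacy_convolution2d(x, 8, (3, 3), activation_fn=tf.nn.relu)
# After:
y = tf.contrib.layers.convolution2d(x, 8, [3, 3], activation_fn=tf.nn.relu)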