Rename weights in Dense layer to kernel, and add base layer aliases for (non_)trainable_weights.

Change: 144752883
Francois Chollet authored on 2017-01-17 14:04:17 -08:00; committed by TensorFlower Gardener
parent 66b5684133
commit 0662eabf9d
5 changed files with 66 additions and 52 deletions

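For orientation, a hedged before/after sketch of the renamed surface (illustrative only; the import style mirrors the tests in this commit):

from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import random_ops

# Constructor arguments renamed: weights_initializer -> kernel_initializer,
# weights_regularizer -> kernel_regularizer.
dense = core_layers.Dense(2, kernel_regularizer=None, name='my_dense')
_ = dense(random_ops.random_uniform((5, 3), seed=1))
print(dense.kernel.name)  # now 'my_dense/kernel:0'; previously dense.w, 'my_dense/weights:0'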
View File

@@ -1385,7 +1385,8 @@ def fully_connected(inputs,
   if not isinstance(num_outputs, six.integer_types):
     raise ValueError('num_outputs should be int or long, got %s.', num_outputs)

-  layer_variable_getter = _build_variable_getter({'bias': 'biases'})
+  layer_variable_getter = _build_variable_getter({'bias': 'biases',
+                                                  'kernel': 'weights'})

   with variable_scope.variable_scope(
       scope, 'fully_connected', [inputs],
@@ -1395,9 +1396,9 @@ def fully_connected(inputs,
         units=num_outputs,
         activation=None,
         use_bias=not normalizer_fn and biases_initializer,
-        weights_initializer=weights_initializer,
+        kernel_initializer=weights_initializer,
         bias_initializer=biases_initializer,
-        weights_regularizer=weights_regularizer,
+        kernel_regularizer=weights_regularizer,
         bias_regularizer=biases_regularizer,
         activity_regularizer=None,
         trainable=trainable,
@@ -1408,7 +1409,7 @@ def fully_connected(inputs,
     outputs = layer.apply(inputs)

     # Add variables to collections.
-    _add_variable_to_collections(layer.w, variables_collections, 'weights')
+    _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
     if layer.bias is not None:
       _add_variable_to_collections(layer.bias, variables_collections, 'biases')

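The {'kernel': 'weights'} mapping above is what keeps contrib's variable names stable after the rename. A hedged sketch, assuming the TF 1.x-era public API:

import tensorflow as tf

inputs = tf.random_uniform((5, 3), seed=1)
tf.contrib.layers.fully_connected(inputs, 32, scope='fc')
# The Dense layer creates its variable as 'kernel', but the custom getter
# surfaces it under the old contrib name:
print(tf.trainable_variables()[0].name)  # expected: 'fc/weights:0'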
View File

@@ -1563,7 +1563,7 @@ class FCTest(test.TestCase):
       _layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
       wd = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
       self.assertEqual(wd.op.name,
-                       'fully_connected/weights/Regularizer/l2_regularizer')
+                       'fully_connected/kernel/Regularizer/l2_regularizer')
       sess.run(variables_lib.global_variables_initializer())
       self.assertLess(sess.run(wd), 0.4)

View File

@@ -120,6 +120,14 @@ class _Layer(object):
   def non_trainable_variables(self):
     return self._non_trainable_variables if self.trainable else self.variables

+  @property
+  def trainable_weights(self):
+    return self.trainable_variables
+
+  @property
+  def non_trainable_weights(self):
+    return self.non_trainable_variables
+
   @property
   def variables(self):
     """Returns the list of all layer variables/weights.

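A minimal sketch of the new aliases (imports mirror the tests in this commit); trainable_weights and non_trainable_weights simply forward to the existing *_variables properties:

from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import random_ops

dense = core_layers.Dense(2, name='my_dense')
_ = dense(random_ops.random_uniform((5, 2), seed=1))
# The Keras-style names are pure aliases for the *_variables properties.
assert dense.trainable_weights == dense.trainable_variables
assert dense.non_trainable_weights == dense.non_trainable_variables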
View File

@@ -41,10 +41,12 @@ from tensorflow.python.layers import utils
 class Dense(base._Layer):  # pylint: disable=protected-access
   """Densely-connected layer class.

-  This layer implements the operation `outputs = activation(inputs.w + b)`
+  This layer implements the operation:
+  `outputs = activation(inputs.kernel + bias)`
   Where `activation` is the activation function passed as the `activation`
-  argument (if not `None`), `w` is a weights matrix created by the layer,
-  and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
+  argument (if not `None`), `kernel` is a weights matrix created by the layer,
+  and `bias` is a bias vector created by the layer
+  (only if `use_bias` is `True`).

   Note: if the input to the layer has a rank greater than 2, then it is
   flattened prior to the initial matrix multiply by `w`.
@@ -54,9 +56,9 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     activation: Activation function (callable). Set it to None to maintain a
       linear activation.
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer function for the weight matrix.
+    kernel_initializer: Initializer function for the weight matrix.
     bias_initializer: Initializer function for the bias.
-    weights_regularizer: Regularizer function for the weight matrix.
+    kernel_regularizer: Regularizer function for the weight matrix.
     bias_regularizer: Regularizer function for the bias.
     activity_regularizer: Regularizer function for the output.
     trainable: Boolean, if `True` also add variables to the graph collection
@@ -70,21 +72,21 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     units: Python integer, dimensionality of the output space.
     activation: Activation function (callable).
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer instance (or name) for the weight matrix.
+    kernel_initializer: Initializer instance (or name) for the weight matrix.
     bias_initializer: Initializer instance (or name) for the bias.
-    weights_regularizer: Regularizer instance for the weight matrix (callable)
+    kernel_regularizer: Regularizer instance for the weight matrix (callable)
     bias_regularizer: Regularizer instance for the bias (callable).
     activity_regularizer: Regularizer instance for the output (callable)
-    weights: Weight matrix (TensorFlow variable or tensor).
+    kernel: Weight matrix (TensorFlow variable or tensor).
     bias: Bias vector, if applicable (TensorFlow variable or tensor).
   """

   def __init__(self, units,
                activation=None,
                use_bias=True,
-               weights_initializer=None,
+               kernel_initializer=None,
                bias_initializer=init_ops.zeros_initializer(),
-               weights_regularizer=None,
+               kernel_regularizer=None,
                bias_regularizer=None,
                activity_regularizer=None,
                trainable=True,
@@ -94,9 +96,9 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     self.units = units
     self.activation = activation
     self.use_bias = use_bias
-    self.weights_initializer = weights_initializer
+    self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
-    self.weights_regularizer = weights_regularizer
+    self.kernel_regularizer = kernel_regularizer
     self.bias_regularizer = bias_regularizer
     self.activity_regularizer = activity_regularizer
@@ -113,12 +115,12 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     # weight of the layer. If the layer is not trainable
     # (self.trainable = False), the variable will not be added to
     # tf.trainable_variables(), and self.trainable_weights will be empty.
-    self.w = vs.get_variable('weights',
-                             shape=[input_shape[-1].value, self.units],
-                             initializer=self.weights_initializer,
-                             regularizer=self.weights_regularizer,
-                             dtype=self.dtype,
-                             trainable=True)
+    self.kernel = vs.get_variable('kernel',
+                                  shape=[input_shape[-1].value, self.units],
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  dtype=self.dtype,
+                                  trainable=True)
     if self.use_bias:
       self.bias = vs.get_variable('bias',
                                   shape=[self.units,],
@@ -140,7 +142,7 @@ class Dense(base._Layer):  # pylint: disable=protected-access
       output_shape_tensor = array_ops.stack(output_shape_tensors)
       inputs = array_ops.reshape(inputs, [-1, input_dim])

-    outputs = standard_ops.matmul(inputs, self.w)
+    outputs = standard_ops.matmul(inputs, self.kernel)

     if self.use_bias:
       outputs = nn.bias_add(outputs, self.bias)
@@ -158,9 +160,9 @@ def dense(
     inputs, units,
     activation=None,
     use_bias=True,
-    weights_initializer=None,
+    kernel_initializer=None,
    bias_initializer=init_ops.zeros_initializer(),
-    weights_regularizer=None,
+    kernel_regularizer=None,
     bias_regularizer=None,
     activity_regularizer=None,
     trainable=True,
@@ -168,10 +170,12 @@ def dense(
     reuse=False):
   """Functional interface for the densely-connected layer.

-  This layer implements the operation `outputs = activation(inputs.w + b)`
+  This layer implements the operation:
+  `outputs = activation(inputs.kernel + bias)`
   Where `activation` is the activation function passed as the `activation`
-  argument (if not `None`), `w` is a weights matrix created by the layer,
-  and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
+  argument (if not `None`), `kernel` is a weights matrix created by the layer,
+  and `bias` is a bias vector created by the layer
+  (only if `use_bias` is `True`).

   Note: if the `inputs` tensor has a rank greater than 2, then it is
   flattened prior to the initial matrix multiply by `w`.
@@ -182,9 +186,9 @@ def dense(
     activation: Activation function (callable). Set it to None to maintain a
       linear activation.
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer function for the weight matrix.
+    kernel_initializer: Initializer function for the weight matrix.
     bias_initializer: Initializer function for the bias.
-    weights_regularizer: Regularizer function for the weight matrix.
+    kernel_regularizer: Regularizer function for the weight matrix.
     bias_regularizer: Regularizer function for the bias.
     activity_regularizer: Regularizer function for the output.
     trainable: Boolean, if `True` also add variables to the graph collection
@@ -199,9 +203,9 @@ def dense(
   layer = Dense(units,
                 activation=activation,
                 use_bias=use_bias,
-                weights_initializer=weights_initializer,
+                kernel_initializer=kernel_initializer,
                 bias_initializer=bias_initializer,
-                weights_regularizer=weights_regularizer,
+                kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
                 activity_regularizer=activity_regularizer,
                 trainable=trainable,

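A hedged usage sketch of the functional interface after the rename, reusing the regularizer style from the tests below:

from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops

regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = core_layers.dense(
    inputs, 2, kernel_regularizer=regularizer, name='my_dense')
# The weight matrix is created as 'my_dense/kernel:0' and the regularization
# loss is added to ops.GraphKeys.REGULARIZATION_LOSSES, as asserted below.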
View File

@@ -39,7 +39,7 @@ class DenseTest(test.TestCase):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
     self.assertEqual(dense.units, 2)
     self.assertEqual(dense.activation, nn_ops.relu)
-    self.assertEqual(dense.weights_regularizer, None)
+    self.assertEqual(dense.kernel_regularizer, None)
     self.assertEqual(dense.bias_regularizer, None)
     self.assertEqual(dense.activity_regularizer, None)
     self.assertEqual(dense.use_bias, True)
@@ -55,36 +55,37 @@ class DenseTest(test.TestCase):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w, dense.bias])
-    self.assertListEqual(dense.trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
+    self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense.non_trainable_variables, [])
-    self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense._non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
-    self.assertEqual(dense.w.name, 'my_dense/weights:0')
+    self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias.name, 'my_dense/bias:0')

   def testNoBias(self):
     dense = core_layers.Dense(2, use_bias=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w])
-    self.assertListEqual(dense.trainable_variables, [dense.w])
+    self.assertListEqual(dense.variables, [dense.kernel])
+    self.assertListEqual(dense.trainable_variables, [dense.kernel])
     self.assertListEqual(dense.non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
-    self.assertEqual(dense.w.name, 'my_dense/weights:0')
+    self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias, None)

   def testNonTrainable(self):
     dense = core_layers.Dense(2, trainable=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w, dense.bias])
-    self.assertListEqual(dense.non_trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
+    self.assertListEqual(dense.non_trainable_variables,
+                         [dense.kernel, dense.bias])
     self.assertListEqual(dense.trainable_variables, [])
-    self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense._non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
@@ -149,25 +150,25 @@ class DenseTest(test.TestCase):
     self.assertEqual(len(loss_keys), 1)
     self.assertListEqual(dense.losses, loss_keys)

-  def testWeightsRegularizer(self):
+  def testKernelRegularizer(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     dense = core_layers.Dense(
-        2, name='my_dense', weights_regularizer=regularizer)
+        2, name='my_dense', kernel_regularizer=regularizer)
     inputs = random_ops.random_uniform((5, 3), seed=1)
     _ = dense(inputs)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
     self.assertListEqual(dense.losses, loss_keys)

-  def testWeightsRegularizerWithReuse(self):
+  def testKernelRegularizerWithReuse(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     inputs = random_ops.random_uniform((5, 3), seed=1)
     _ = core_layers.dense(
-        inputs, 2, name='my_dense', weights_regularizer=regularizer)
+        inputs, 2, name='my_dense', kernel_regularizer=regularizer)
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
     _ = core_layers.dense(
-        inputs, 2, name='my_dense', weights_regularizer=regularizer, reuse=True)
+        inputs, 2, name='my_dense', kernel_regularizer=regularizer, reuse=True)
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
@@ -237,17 +238,17 @@ class DenseTest(test.TestCase):
       inputs = random_ops.random_uniform((5, 3), seed=1)
       core_layers.dense(inputs, 2, name='my_dense')
       var = variables.trainable_variables()[0]
-      self.assertEqual(var.name, 'test/my_dense/weights:0')
+      self.assertEqual(var.name, 'test/my_dense/kernel:0')
     with variable_scope.variable_scope('test1') as scope:
       inputs = random_ops.random_uniform((5, 3), seed=1)
       core_layers.dense(inputs, 2, name=scope)
       var = variables.trainable_variables()[2]
-      self.assertEqual(var.name, 'test1/weights:0')
+      self.assertEqual(var.name, 'test1/kernel:0')
     with variable_scope.variable_scope('test2'):
       inputs = random_ops.random_uniform((5, 3), seed=1)
       core_layers.dense(inputs, 2)
       var = variables.trainable_variables()[4]
-      self.assertEqual(var.name, 'test2/dense/weights:0')
+      self.assertEqual(var.name, 'test2/dense/kernel:0')


 class DropoutTest(test.TestCase):