Rename weights in Dense layer to kernel, and add base layer aliases for (non_)trainable_weights.
Change: 144752883
This commit is contained in:
parent
66b5684133
commit
0662eabf9d
@ -1385,7 +1385,8 @@ def fully_connected(inputs,
|
||||
if not isinstance(num_outputs, six.integer_types):
|
||||
raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
|
||||
|
||||
layer_variable_getter = _build_variable_getter({'bias': 'biases'})
|
||||
layer_variable_getter = _build_variable_getter({'bias': 'biases',
|
||||
'kernel': 'weights'})
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
scope, 'fully_connected', [inputs],
|
||||
@ -1395,9 +1396,9 @@ def fully_connected(inputs,
|
||||
units=num_outputs,
|
||||
activation=None,
|
||||
use_bias=not normalizer_fn and biases_initializer,
|
||||
weights_initializer=weights_initializer,
|
||||
kernel_initializer=weights_initializer,
|
||||
bias_initializer=biases_initializer,
|
||||
weights_regularizer=weights_regularizer,
|
||||
kernel_regularizer=weights_regularizer,
|
||||
bias_regularizer=biases_regularizer,
|
||||
activity_regularizer=None,
|
||||
trainable=trainable,
|
||||
@ -1408,7 +1409,7 @@ def fully_connected(inputs,
|
||||
outputs = layer.apply(inputs)
|
||||
|
||||
# Add variables to collections.
|
||||
_add_variable_to_collections(layer.w, variables_collections, 'weights')
|
||||
_add_variable_to_collections(layer.kernel, variables_collections, 'weights')
|
||||
if layer.bias is not None:
|
||||
_add_variable_to_collections(layer.bias, variables_collections, 'biases')
|
||||
|
||||
|
@ -1563,7 +1563,7 @@ class FCTest(test.TestCase):
|
||||
_layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
|
||||
wd = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
|
||||
self.assertEqual(wd.op.name,
|
||||
'fully_connected/weights/Regularizer/l2_regularizer')
|
||||
'fully_connected/kernel/Regularizer/l2_regularizer')
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
self.assertLess(sess.run(wd), 0.4)
|
||||
|
||||
|
@ -120,6 +120,14 @@ class _Layer(object):
|
||||
def non_trainable_variables(self):
|
||||
return self._non_trainable_variables if self.trainable else self.variables
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
return self.trainable_variables
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
return self.non_trainable_variables
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
"""Returns the list of all layer variables/weights.
|
||||
|
@ -41,10 +41,12 @@ from tensorflow.python.layers import utils
|
||||
class Dense(base._Layer): # pylint: disable=protected-access
|
||||
"""Densely-connected layer class.
|
||||
|
||||
This layer implements the operation `outputs = activation(inputs.w + b)`
|
||||
This layer implements the operation:
|
||||
`outputs = activation(inputs.kernel + bias)`
|
||||
Where `activation` is the activation function passed as the `activation`
|
||||
argument (if not `None`), `w` is a weights matrix created by the layer,
|
||||
and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
|
||||
argument (if not `None`), `kernel` is a weights matrix created by the layer,
|
||||
and `bias` is a bias vector created by the layer
|
||||
(only if `use_bias` is `True`).
|
||||
|
||||
Note: if the input to the layer has a rank greater than 2, then it is
|
||||
flattened prior to the initial matrix multiply by `kernel`.
|
||||
@ -54,9 +56,9 @@ class Dense(base._Layer): # pylint: disable=protected-access
|
||||
activation: Activation function (callable). Set it to None to maintain a
|
||||
linear activation.
|
||||
use_bias: Boolean, whether the layer uses a bias.
|
||||
weights_initializer: Initializer function for the weight matrix.
|
||||
kernel_initializer: Initializer function for the weight matrix.
|
||||
bias_initializer: Initializer function for the bias.
|
||||
weights_regularizer: Regularizer function for the weight matrix.
|
||||
kernel_regularizer: Regularizer function for the weight matrix.
|
||||
bias_regularizer: Regularizer function for the bias.
|
||||
activity_regularizer: Regularizer function for the output.
|
||||
trainable: Boolean, if `True` also add variables to the graph collection
|
||||
@ -70,21 +72,21 @@ class Dense(base._Layer): # pylint: disable=protected-access
|
||||
units: Python integer, dimensionality of the output space.
|
||||
activation: Activation function (callable).
|
||||
use_bias: Boolean, whether the layer uses a bias.
|
||||
weights_initializer: Initializer instance (or name) for the weight matrix.
|
||||
kernel_initializer: Initializer instance (or name) for the weight matrix.
|
||||
bias_initializer: Initializer instance (or name) for the bias.
|
||||
weights_regularizer: Regularizer instance for the weight matrix (callable)
|
||||
kernel_regularizer: Regularizer instance for the weight matrix (callable)
|
||||
bias_regularizer: Regularizer instance for the bias (callable).
|
||||
activity_regularizer: Regularizer instance for the output (callable)
|
||||
weights: Weight matrix (TensorFlow variable or tensor).
|
||||
kernel: Weight matrix (TensorFlow variable or tensor).
|
||||
bias: Bias vector, if applicable (TensorFlow variable or tensor).
|
||||
"""
|
||||
|
||||
def __init__(self, units,
|
||||
activation=None,
|
||||
use_bias=True,
|
||||
weights_initializer=None,
|
||||
kernel_initializer=None,
|
||||
bias_initializer=init_ops.zeros_initializer(),
|
||||
weights_regularizer=None,
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
trainable=True,
|
||||
@ -94,9 +96,9 @@ class Dense(base._Layer): # pylint: disable=protected-access
|
||||
self.units = units
|
||||
self.activation = activation
|
||||
self.use_bias = use_bias
|
||||
self.weights_initializer = weights_initializer
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
self.weights_regularizer = weights_regularizer
|
||||
self.kernel_regularizer = kernel_regularizer
|
||||
self.bias_regularizer = bias_regularizer
|
||||
self.activity_regularizer = activity_regularizer
|
||||
|
||||
@ -113,12 +115,12 @@ class Dense(base._Layer): # pylint: disable=protected-access
|
||||
# weight of the layer. If the layer is not trainable
|
||||
# (self.trainable = False), the variable will not be added to
|
||||
# tf.trainable_variables(), and self.trainable_weights will be empty.
|
||||
self.w = vs.get_variable('weights',
|
||||
shape=[input_shape[-1].value, self.units],
|
||||
initializer=self.weights_initializer,
|
||||
regularizer=self.weights_regularizer,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
self.kernel = vs.get_variable('kernel',
|
||||
shape=[input_shape[-1].value, self.units],
|
||||
initializer=self.kernel_initializer,
|
||||
regularizer=self.kernel_regularizer,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
if self.use_bias:
|
||||
self.bias = vs.get_variable('bias',
|
||||
shape=[self.units,],
|
||||
@ -140,7 +142,7 @@ class Dense(base._Layer): # pylint: disable=protected-access
|
||||
output_shape_tensor = array_ops.stack(output_shape_tensors)
|
||||
inputs = array_ops.reshape(inputs, [-1, input_dim])
|
||||
|
||||
outputs = standard_ops.matmul(inputs, self.w)
|
||||
outputs = standard_ops.matmul(inputs, self.kernel)
|
||||
if self.use_bias:
|
||||
outputs = nn.bias_add(outputs, self.bias)
|
||||
|
||||
@ -158,9 +160,9 @@ def dense(
|
||||
inputs, units,
|
||||
activation=None,
|
||||
use_bias=True,
|
||||
weights_initializer=None,
|
||||
kernel_initializer=None,
|
||||
bias_initializer=init_ops.zeros_initializer(),
|
||||
weights_regularizer=None,
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
trainable=True,
|
||||
@ -168,10 +170,12 @@ def dense(
|
||||
reuse=False):
|
||||
"""Functional interface for the densely-connected layer.
|
||||
|
||||
This layer implements the operation `outputs = activation(inputs.w + b)`
|
||||
This layer implements the operation:
|
||||
`outputs = activation(inputs.kernel + bias)`
|
||||
Where `activation` is the activation function passed as the `activation`
|
||||
argument (if not `None`), `w` is a weights matrix created by the layer,
|
||||
and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
|
||||
argument (if not `None`), `kernel` is a weights matrix created by the layer,
|
||||
and `bias` is a bias vector created by the layer
|
||||
(only if `use_bias` is `True`).
|
||||
|
||||
Note: if the `inputs` tensor has a rank greater than 2, then it is
|
||||
flattened prior to the initial matrix multiply by `kernel`.
|
||||
@ -182,9 +186,9 @@ def dense(
|
||||
activation: Activation function (callable). Set it to None to maintain a
|
||||
linear activation.
|
||||
use_bias: Boolean, whether the layer uses a bias.
|
||||
weights_initializer: Initializer function for the weight matrix.
|
||||
kernel_initializer: Initializer function for the weight matrix.
|
||||
bias_initializer: Initializer function for the bias.
|
||||
weights_regularizer: Regularizer function for the weight matrix.
|
||||
kernel_regularizer: Regularizer function for the weight matrix.
|
||||
bias_regularizer: Regularizer function for the bias.
|
||||
activity_regularizer: Regularizer function for the output.
|
||||
trainable: Boolean, if `True` also add variables to the graph collection
|
||||
@ -199,9 +203,9 @@ def dense(
|
||||
layer = Dense(units,
|
||||
activation=activation,
|
||||
use_bias=use_bias,
|
||||
weights_initializer=weights_initializer,
|
||||
kernel_initializer=kernel_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
weights_regularizer=weights_regularizer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
trainable=trainable,
|
||||
|
@ -39,7 +39,7 @@ class DenseTest(test.TestCase):
|
||||
dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
|
||||
self.assertEqual(dense.units, 2)
|
||||
self.assertEqual(dense.activation, nn_ops.relu)
|
||||
self.assertEqual(dense.weights_regularizer, None)
|
||||
self.assertEqual(dense.kernel_regularizer, None)
|
||||
self.assertEqual(dense.bias_regularizer, None)
|
||||
self.assertEqual(dense.activity_regularizer, None)
|
||||
self.assertEqual(dense.use_bias, True)
|
||||
@ -55,36 +55,37 @@ class DenseTest(test.TestCase):
|
||||
dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
|
||||
inputs = random_ops.random_uniform((5, 2), seed=1)
|
||||
_ = dense(inputs)
|
||||
self.assertListEqual(dense.variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense.trainable_variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.non_trainable_variables, [])
|
||||
self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense._non_trainable_variables, [])
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
|
||||
self.assertEqual(dense.w.name, 'my_dense/weights:0')
|
||||
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
|
||||
self.assertEqual(dense.bias.name, 'my_dense/bias:0')
|
||||
|
||||
def testNoBias(self):
|
||||
dense = core_layers.Dense(2, use_bias=False, name='my_dense')
|
||||
inputs = random_ops.random_uniform((5, 2), seed=1)
|
||||
_ = dense(inputs)
|
||||
self.assertListEqual(dense.variables, [dense.w])
|
||||
self.assertListEqual(dense.trainable_variables, [dense.w])
|
||||
self.assertListEqual(dense.variables, [dense.kernel])
|
||||
self.assertListEqual(dense.trainable_variables, [dense.kernel])
|
||||
self.assertListEqual(dense.non_trainable_variables, [])
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
|
||||
self.assertEqual(dense.w.name, 'my_dense/weights:0')
|
||||
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
|
||||
self.assertEqual(dense.bias, None)
|
||||
|
||||
def testNonTrainable(self):
|
||||
dense = core_layers.Dense(2, trainable=False, name='my_dense')
|
||||
inputs = random_ops.random_uniform((5, 2), seed=1)
|
||||
_ = dense(inputs)
|
||||
self.assertListEqual(dense.variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense.non_trainable_variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.non_trainable_variables,
|
||||
[dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.trainable_variables, [])
|
||||
self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
|
||||
self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense._non_trainable_variables, [])
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
|
||||
@ -149,25 +150,25 @@ class DenseTest(test.TestCase):
|
||||
self.assertEqual(len(loss_keys), 1)
|
||||
self.assertListEqual(dense.losses, loss_keys)
|
||||
|
||||
def testWeightsRegularizer(self):
|
||||
def testKernelRegularizer(self):
|
||||
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
|
||||
dense = core_layers.Dense(
|
||||
2, name='my_dense', weights_regularizer=regularizer)
|
||||
2, name='my_dense', kernel_regularizer=regularizer)
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
_ = dense(inputs)
|
||||
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
|
||||
self.assertEqual(len(loss_keys), 1)
|
||||
self.assertListEqual(dense.losses, loss_keys)
|
||||
|
||||
def testWeightsRegularizerWithReuse(self):
|
||||
def testKernelRegularizerWithReuse(self):
|
||||
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
_ = core_layers.dense(
|
||||
inputs, 2, name='my_dense', weights_regularizer=regularizer)
|
||||
inputs, 2, name='my_dense', kernel_regularizer=regularizer)
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
|
||||
_ = core_layers.dense(
|
||||
inputs, 2, name='my_dense', weights_regularizer=regularizer, reuse=True)
|
||||
inputs, 2, name='my_dense', kernel_regularizer=regularizer, reuse=True)
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
|
||||
|
||||
@ -237,17 +238,17 @@ class DenseTest(test.TestCase):
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
core_layers.dense(inputs, 2, name='my_dense')
|
||||
var = variables.trainable_variables()[0]
|
||||
self.assertEqual(var.name, 'test/my_dense/weights:0')
|
||||
self.assertEqual(var.name, 'test/my_dense/kernel:0')
|
||||
with variable_scope.variable_scope('test1') as scope:
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
core_layers.dense(inputs, 2, name=scope)
|
||||
var = variables.trainable_variables()[2]
|
||||
self.assertEqual(var.name, 'test1/weights:0')
|
||||
self.assertEqual(var.name, 'test1/kernel:0')
|
||||
with variable_scope.variable_scope('test2'):
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
core_layers.dense(inputs, 2)
|
||||
var = variables.trainable_variables()[4]
|
||||
self.assertEqual(var.name, 'test2/dense/weights:0')
|
||||
self.assertEqual(var.name, 'test2/dense/kernel:0')
|
||||
|
||||
|
||||
class DropoutTest(test.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user