From 0662eabf9d6d670bd9a741ea3a3eb0c9f0005850 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Tue, 17 Jan 2017 14:04:17 -0800
Subject: [PATCH] Rename `weights` in `Dense` layer to `kernel`, and add base
 layer aliases for `(non_)trainable_weights`.

Change: 144752883
---
 .../contrib/layers/python/layers/layers.py |  9 +--
 .../layers/python/layers/layers_test.py    |  2 +-
 tensorflow/python/layers/base.py           |  8 +++
 tensorflow/python/layers/core.py           | 60 ++++++++++---------
 tensorflow/python/layers/core_test.py      | 39 ++++++------
 5 files changed, 66 insertions(+), 52 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 2673495b904..e47342f9663 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1385,7 +1385,8 @@ def fully_connected(inputs,
   if not isinstance(num_outputs, six.integer_types):
     raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
 
-  layer_variable_getter = _build_variable_getter({'bias': 'biases'})
+  layer_variable_getter = _build_variable_getter({'bias': 'biases',
+                                                  'kernel': 'weights'})
 
   with variable_scope.variable_scope(
       scope, 'fully_connected', [inputs],
@@ -1395,9 +1396,9 @@
         units=num_outputs,
         activation=None,
         use_bias=not normalizer_fn and biases_initializer,
-        weights_initializer=weights_initializer,
+        kernel_initializer=weights_initializer,
         bias_initializer=biases_initializer,
-        weights_regularizer=weights_regularizer,
+        kernel_regularizer=weights_regularizer,
         bias_regularizer=biases_regularizer,
         activity_regularizer=None,
         trainable=trainable,
@@ -1408,7 +1409,7 @@
     outputs = layer.apply(inputs)
 
     # Add variables to collections.
-    _add_variable_to_collections(layer.w, variables_collections, 'weights')
+    _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
     if layer.bias is not None:
       _add_variable_to_collections(layer.bias, variables_collections, 'biases')
 
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index d1b35e33c26..6043d4dc0e3 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1563,7 +1563,7 @@ class FCTest(test.TestCase):
       _layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
       wd = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
       self.assertEqual(wd.op.name,
-                       'fully_connected/weights/Regularizer/l2_regularizer')
+                       'fully_connected/kernel/Regularizer/l2_regularizer')
       sess.run(variables_lib.global_variables_initializer())
       self.assertLess(sess.run(wd), 0.4)
 
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 74a6052ff6b..853b08b2a50 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -120,6 +120,14 @@ class _Layer(object):
   def non_trainable_variables(self):
     return self._non_trainable_variables if self.trainable else self.variables
 
+  @property
+  def trainable_weights(self):
+    return self.trainable_variables
+
+  @property
+  def non_trainable_weights(self):
+    return self.non_trainable_variables
+
   @property
   def variables(self):
     """Returns the list of all layer variables/weights.
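[Note, not part of the patch: the base.py hunk above adds `trainable_weights`
and `non_trainable_weights` as Keras-style aliases on the base layer. A
minimal usage sketch, assuming the TF-1.x-era internal module paths that the
tests in this patch already use:

    from tensorflow.python.layers import core as core_layers
    from tensorflow.python.ops import random_ops

    dense = core_layers.Dense(2, name='my_dense')
    # Building the layer (first call) creates its variables.
    _ = dense(random_ops.random_uniform((5, 3), seed=1))
    # The new aliases simply forward to the existing *_variables properties:
    assert dense.trainable_weights == dense.trainable_variables
    assert dense.non_trainable_weights == dense.non_trainable_variables
]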
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index c662478cccb..b8b6cd97da7 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -41,10 +41,12 @@ from tensorflow.python.layers import utils
 class Dense(base._Layer):  # pylint: disable=protected-access
   """Densely-connected layer class.
 
-  This layer implements the operation `outputs = activation(inputs.w + b)`
+  This layer implements the operation:
+  `outputs = activation(inputs.kernel + bias)`
   Where `activation` is the activation function passed as the `activation`
-  argument (if not `None`), `w` is a weights matrix created by the layer,
-  and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
+  argument (if not `None`), `kernel` is a weights matrix created by the layer,
+  and `bias` is a bias vector created by the layer
+  (only if `use_bias` is `True`).
 
   Note: if the input to the layer has a rank greater than 2, then it is
-  flattened prior to the initial matrix multiply by `w`.
+  flattened prior to the initial matrix multiply by `kernel`.
@@ -54,9 +56,9 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     activation: Activation function (callable). Set it to None to maintain a
       linear activation.
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer function for the weight matrix.
+    kernel_initializer: Initializer function for the weight matrix.
     bias_initializer: Initializer function for the bias.
-    weights_regularizer: Regularizer function for the weight matrix.
+    kernel_regularizer: Regularizer function for the weight matrix.
     bias_regularizer: Regularizer function for the bias.
     activity_regularizer: Regularizer function for the output.
     trainable: Boolean, if `True` also add variables to the graph collection
@@ -70,21 +72,21 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     units: Python integer, dimensionality of the output space.
     activation: Activation function (callable).
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer instance (or name) for the weight matrix.
+    kernel_initializer: Initializer instance (or name) for the weight matrix.
     bias_initializer: Initializer instance (or name) for the bias.
-    weights_regularizer: Regularizer instance for the weight matrix (callable)
+    kernel_regularizer: Regularizer instance for the weight matrix (callable).
     bias_regularizer: Regularizer instance for the bias (callable).
     activity_regularizer: Regularizer instance for the output (callable)
-    weights: Weight matrix (TensorFlow variable or tensor).
+    kernel: Weight matrix (TensorFlow variable or tensor).
     bias: Bias vector, if applicable (TensorFlow variable or tensor).
   """
 
   def __init__(self, units,
                activation=None,
                use_bias=True,
-               weights_initializer=None,
+               kernel_initializer=None,
                bias_initializer=init_ops.zeros_initializer(),
-               weights_regularizer=None,
+               kernel_regularizer=None,
                bias_regularizer=None,
                activity_regularizer=None,
                trainable=True,
@@ -94,9 +96,9 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     self.units = units
     self.activation = activation
     self.use_bias = use_bias
-    self.weights_initializer = weights_initializer
+    self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
-    self.weights_regularizer = weights_regularizer
+    self.kernel_regularizer = kernel_regularizer
     self.bias_regularizer = bias_regularizer
     self.activity_regularizer = activity_regularizer
 
@@ -113,12 +115,12 @@ class Dense(base._Layer):  # pylint: disable=protected-access
     # weight of the layer. If the layer is not trainable
     # (self.trainable = False), the variable will not be added to
     # tf.trainable_variables(), and self.trainable_weights will be empty.
-    self.w = vs.get_variable('weights',
-                             shape=[input_shape[-1].value, self.units],
-                             initializer=self.weights_initializer,
-                             regularizer=self.weights_regularizer,
-                             dtype=self.dtype,
-                             trainable=True)
+    self.kernel = vs.get_variable('kernel',
+                                  shape=[input_shape[-1].value, self.units],
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  dtype=self.dtype,
+                                  trainable=True)
     if self.use_bias:
       self.bias = vs.get_variable('bias',
                                   shape=[self.units,],
@@ -140,7 +142,7 @@ class Dense(base._Layer):  # pylint: disable=protected-access
       output_shape_tensor = array_ops.stack(output_shape_tensors)
       inputs = array_ops.reshape(inputs, [-1, input_dim])
 
-    outputs = standard_ops.matmul(inputs, self.w)
+    outputs = standard_ops.matmul(inputs, self.kernel)
     if self.use_bias:
       outputs = nn.bias_add(outputs, self.bias)
 
@@ -158,9 +160,9 @@ def dense(
     inputs, units,
     activation=None,
     use_bias=True,
-    weights_initializer=None,
+    kernel_initializer=None,
     bias_initializer=init_ops.zeros_initializer(),
-    weights_regularizer=None,
+    kernel_regularizer=None,
     bias_regularizer=None,
     activity_regularizer=None,
     trainable=True,
@@ -168,10 +170,12 @@
     reuse=False):
   """Functional interface for the densely-connected layer.
 
-  This layer implements the operation `outputs = activation(inputs.w + b)`
+  This layer implements the operation:
+  `outputs = activation(inputs.kernel + bias)`
   Where `activation` is the activation function passed as the `activation`
-  argument (if not `None`), `w` is a weights matrix created by the layer,
-  and `b` is a bias vector created by the layer (only if `use_bias` is `True`).
+  argument (if not `None`), `kernel` is a weights matrix created by the layer,
+  and `bias` is a bias vector created by the layer
+  (only if `use_bias` is `True`).
 
   Note: if the `inputs` tensor has a rank greater than 2, then it is
-  flattened prior to the initial matrix multiply by `w`.
+  flattened prior to the initial matrix multiply by `kernel`.
@@ -182,9 +186,9 @@
     activation: Activation function (callable). Set it to None to maintain a
       linear activation.
     use_bias: Boolean, whether the layer uses a bias.
-    weights_initializer: Initializer function for the weight matrix.
+    kernel_initializer: Initializer function for the weight matrix.
     bias_initializer: Initializer function for the bias.
-    weights_regularizer: Regularizer function for the weight matrix.
+    kernel_regularizer: Regularizer function for the weight matrix.
     bias_regularizer: Regularizer function for the bias.
     activity_regularizer: Regularizer function for the output.
     trainable: Boolean, if `True` also add variables to the graph collection
@@ -199,9 +203,9 @@ def dense(
   layer = Dense(units,
                 activation=activation,
                 use_bias=use_bias,
-                weights_initializer=weights_initializer,
+                kernel_initializer=kernel_initializer,
                 bias_initializer=bias_initializer,
-                weights_regularizer=weights_regularizer,
+                kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
                 activity_regularizer=activity_regularizer,
                 trainable=trainable,
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index c1fbe957df6..50e3ebc72f5 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -39,7 +39,7 @@ class DenseTest(test.TestCase):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
     self.assertEqual(dense.units, 2)
     self.assertEqual(dense.activation, nn_ops.relu)
-    self.assertEqual(dense.weights_regularizer, None)
+    self.assertEqual(dense.kernel_regularizer, None)
     self.assertEqual(dense.bias_regularizer, None)
     self.assertEqual(dense.activity_regularizer, None)
     self.assertEqual(dense.use_bias, True)
@@ -55,36 +55,37 @@ class DenseTest(test.TestCase):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w, dense.bias])
-    self.assertListEqual(dense.trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
+    self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense.non_trainable_variables, [])
-    self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense._non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
-    self.assertEqual(dense.w.name, 'my_dense/weights:0')
+    self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias.name, 'my_dense/bias:0')
 
   def testNoBias(self):
     dense = core_layers.Dense(2, use_bias=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w])
-    self.assertListEqual(dense.trainable_variables, [dense.w])
+    self.assertListEqual(dense.variables, [dense.kernel])
+    self.assertListEqual(dense.trainable_variables, [dense.kernel])
     self.assertListEqual(dense.non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
-    self.assertEqual(dense.w.name, 'my_dense/weights:0')
+    self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias, None)
 
   def testNonTrainable(self):
     dense = core_layers.Dense(2, trainable=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
     _ = dense(inputs)
-    self.assertListEqual(dense.variables, [dense.w, dense.bias])
-    self.assertListEqual(dense.non_trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
+    self.assertListEqual(dense.non_trainable_variables,
+                         [dense.kernel, dense.bias])
     self.assertListEqual(dense.trainable_variables, [])
-    self.assertListEqual(dense._trainable_variables, [dense.w, dense.bias])
+    self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense._non_trainable_variables, [])
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
@@ -149,25 +150,25 @@ class DenseTest(test.TestCase):
     self.assertEqual(len(loss_keys), 1)
     self.assertListEqual(dense.losses, loss_keys)
 
-  def testWeightsRegularizer(self):
+  def testKernelRegularizer(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     dense = core_layers.Dense(
-        2, name='my_dense', weights_regularizer=regularizer)
+        2, name='my_dense', kernel_regularizer=regularizer)
     inputs = random_ops.random_uniform((5, 3), seed=1)
     _ = dense(inputs)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
     self.assertListEqual(dense.losses, loss_keys)
 
-  def testWeightsRegularizerWithReuse(self):
+  def testKernelRegularizerWithReuse(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     inputs = random_ops.random_uniform((5, 3), seed=1)
     _ = core_layers.dense(
-        inputs, 2, name='my_dense', weights_regularizer=regularizer)
+        inputs, 2, name='my_dense', kernel_regularizer=regularizer)
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
     _ = core_layers.dense(
-        inputs, 2, name='my_dense', weights_regularizer=regularizer, reuse=True)
+        inputs, 2, name='my_dense', kernel_regularizer=regularizer, reuse=True)
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
 
@@ -237,17 +238,17 @@ class DenseTest(test.TestCase):
       inputs = random_ops.random_uniform((5, 3), seed=1)
      core_layers.dense(inputs, 2, name='my_dense')
       var = variables.trainable_variables()[0]
-      self.assertEqual(var.name, 'test/my_dense/weights:0')
+      self.assertEqual(var.name, 'test/my_dense/kernel:0')
     with variable_scope.variable_scope('test1') as scope:
       inputs = random_ops.random_uniform((5, 3), seed=1)
       core_layers.dense(inputs, 2, name=scope)
       var = variables.trainable_variables()[2]
-      self.assertEqual(var.name, 'test1/weights:0')
+      self.assertEqual(var.name, 'test1/kernel:0')
     with variable_scope.variable_scope('test2'):
       inputs = random_ops.random_uniform((5, 3), seed=1)
       core_layers.dense(inputs, 2)
       var = variables.trainable_variables()[4]
-      self.assertEqual(var.name, 'test2/dense/weights:0')
+      self.assertEqual(var.name, 'test2/dense/kernel:0')
 
 
 class DropoutTest(test.TestCase):
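[Note, not part of the patch: for callers migrating with this change, the
core-layer arguments move from `weights_*` to `kernel_*`, and the created
variable is named `kernel` instead of `weights`, while
`tf.contrib.layers.fully_connected` keeps its public `weights_*` arguments and
remaps the variable name through `_build_variable_getter`. A hedged
before/after sketch, using the internal module paths from the tests
(`init_ops.zeros_initializer` is chosen only for illustration):

    from tensorflow.python.layers import core as core_layers
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import random_ops

    inputs = random_ops.random_uniform((5, 3), seed=1)
    # Before this patch:
    #   core_layers.dense(inputs, 2, name='my_dense',
    #                     weights_initializer=init_ops.zeros_initializer())
    # After this patch:
    outputs = core_layers.dense(inputs, 2, name='my_dense',
                                kernel_initializer=init_ops.zeros_initializer())
    # The variable is now 'my_dense/kernel:0' (previously 'my_dense/weights:0'),
    # so checkpoints written before the rename need a variable-name mapping
    # to restore under the new name.
]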