Add trainable arg to fully_connected.
Add linear, relu, relu6. Remove legacy_relu6 and legacy_convolution2d. Change: 123898222
This commit is contained in:
		
							parent
							
								
									6cc37af4fb
								
							
						
					
					
						commit
						f6acee434c
					
				| @ -48,14 +48,15 @@ __all__ = ['avg_pool2d', | ||||
|            'dropout', | ||||
|            'flatten', | ||||
|            'fully_connected', | ||||
|            'linear', | ||||
|            'max_pool2d', | ||||
|            'one_hot_encoding', | ||||
|            'relu', | ||||
|            'relu6', | ||||
|            'stack', | ||||
|            'legacy_convolution2d', | ||||
|            'legacy_fully_connected', | ||||
|            'legacy_linear', | ||||
|            'legacy_relu', | ||||
|            'legacy_relu6'] | ||||
|            'legacy_relu'] | ||||
| 
 | ||||
| 
 | ||||
| @add_arg_scope | ||||
| @ -107,6 +108,7 @@ def batch_norm(inputs, | ||||
|                reuse=None, | ||||
|                variables_collections=None, | ||||
|                outputs_collections=None, | ||||
|                trainable=True, | ||||
|                scope=None): | ||||
|   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. | ||||
| 
 | ||||
| @ -139,6 +141,8 @@ def batch_norm(inputs, | ||||
|       able to reuse the layer scope must be given. | ||||
|     variables_collections: optional collections for the variables. | ||||
|     outputs_collections: collections to add the outputs. | ||||
|     trainable: If `True` also add variables to the graph collection | ||||
|       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). | ||||
|     scope: Optional scope for `variable_op_scope`. | ||||
| 
 | ||||
|   Returns: | ||||
| @ -160,7 +164,8 @@ def batch_norm(inputs, | ||||
|                                       shape=params_shape, | ||||
|                                       dtype=dtype, | ||||
|                                       initializer=init_ops.zeros_initializer, | ||||
|                                       collections=beta_collections) | ||||
|                                       collections=beta_collections, | ||||
|                                       trainable=trainable) | ||||
|     if scale: | ||||
|       gamma_collections = utils.get_variable_collections(variables_collections, | ||||
|                                                          'gamma') | ||||
| @ -168,7 +173,8 @@ def batch_norm(inputs, | ||||
|                                        shape=params_shape, | ||||
|                                        dtype=dtype, | ||||
|                                        initializer=init_ops.ones_initializer, | ||||
|                                        collections=gamma_collections) | ||||
|                                        collections=gamma_collections, | ||||
|                                        trainable=trainable) | ||||
|     # Create moving_mean and moving_variance variables and add them to the | ||||
|     # appropiate collections. | ||||
|     moving_mean_collections = utils.get_variable_collections( | ||||
| @ -226,6 +232,7 @@ def bias_add(inputs, | ||||
|              reuse=None, | ||||
|              variables_collections=None, | ||||
|              outputs_collections=None, | ||||
|              trainable=True, | ||||
|              scope=None): | ||||
|   """Adds a bias to the inputs. | ||||
| 
 | ||||
| @ -242,6 +249,8 @@ def bias_add(inputs, | ||||
|       able to reuse the layer scope must be given. | ||||
|     variables_collections: optional collections for the variables. | ||||
|     outputs_collections: collections to add the outputs. | ||||
|     trainable: If `True` also add variables to the graph collection | ||||
|       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). | ||||
|     scope: Optional scope for variable_op_scope. | ||||
| 
 | ||||
|   Returns: | ||||
| @ -258,7 +267,8 @@ def bias_add(inputs, | ||||
|                                       dtype=dtype, | ||||
|                                       initializer=initializer, | ||||
|                                       regularizer=regularizer, | ||||
|                                       collections=biases_collections) | ||||
|                                       collections=biases_collections, | ||||
|                                       trainable=trainable) | ||||
|     outputs = nn.bias_add(inputs, biases) | ||||
|     if activation_fn: | ||||
|       outputs = activation_fn(outputs) | ||||
| @ -281,6 +291,7 @@ def convolution2d(inputs, | ||||
|                   reuse=None, | ||||
|                   variables_collections=None, | ||||
|                   outputs_collections=None, | ||||
|                   trainable=True, | ||||
|                   scope=None): | ||||
|   """Adds a 2D convolution followed by an optional batch_norm layer. | ||||
| 
 | ||||
| @ -315,6 +326,8 @@ def convolution2d(inputs, | ||||
|     variables_collections: optional list of collections for all the variables or | ||||
|       a dictionay containing a different list of collection per variable. | ||||
|     outputs_collections: collection to add the outputs. | ||||
|     trainable: If `True` also add variables to the graph collection | ||||
|       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). | ||||
|     scope: Optional scope for `variable_op_scope`. | ||||
| 
 | ||||
|   Returns: | ||||
| @ -336,7 +349,8 @@ def convolution2d(inputs, | ||||
|                                        dtype=dtype, | ||||
|                                        initializer=weights_initializer, | ||||
|                                        regularizer=weights_regularizer, | ||||
|                                        collections=weights_collections) | ||||
|                                        collections=weights_collections, | ||||
|                                        trainable=trainable) | ||||
|     outputs = nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1], | ||||
|                         padding=padding) | ||||
|     if normalizer_fn: | ||||
| @ -351,7 +365,8 @@ def convolution2d(inputs, | ||||
|                                           dtype=dtype, | ||||
|                                           initializer=biases_initializer, | ||||
|                                           regularizer=biases_regularizer, | ||||
|                                           collections=biases_collections) | ||||
|                                           collections=biases_collections, | ||||
|                                           trainable=trainable) | ||||
|         outputs = nn.bias_add(outputs, biases) | ||||
|     if activation_fn: | ||||
|       outputs = activation_fn(outputs) | ||||
| @ -435,6 +450,7 @@ def fully_connected(inputs, | ||||
|                     reuse=None, | ||||
|                     variables_collections=None, | ||||
|                     outputs_collections=None, | ||||
|                     trainable=True, | ||||
|                     scope=None): | ||||
|   """Adds a fully connected layer. | ||||
| 
 | ||||
| @ -465,8 +481,10 @@ def fully_connected(inputs, | ||||
|     reuse: whether or not the layer and its variables should be reused. To be | ||||
|       able to reuse the layer scope must be given. | ||||
|     variables_collections: Optional list of collections for all the variables or | ||||
|       a dictionay containing a different list of collection per variable. | ||||
|       a dictionary containing a different list of collections per variable. | ||||
|     outputs_collections: collection to add the outputs. | ||||
|     trainable: If `True` also add variables to the graph collection | ||||
|       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). | ||||
|     scope: Optional scope for variable_op_scope. | ||||
| 
 | ||||
|   Returns: | ||||
| @ -496,7 +514,8 @@ def fully_connected(inputs, | ||||
|                                        dtype=dtype, | ||||
|                                        initializer=weights_initializer, | ||||
|                                        regularizer=weights_regularizer, | ||||
|                                        collections=weights_collections) | ||||
|                                        collections=weights_collections, | ||||
|                                        trainable=trainable) | ||||
|     if len(static_shape) > 2: | ||||
|       # Reshape inputs | ||||
|       inputs = array_ops.reshape(inputs, [-1, num_input_units]) | ||||
| @ -513,7 +532,8 @@ def fully_connected(inputs, | ||||
|                                           dtype=dtype, | ||||
|                                           initializer=biases_initializer, | ||||
|                                           regularizer=biases_regularizer, | ||||
|                                           collections=biases_collections) | ||||
|                                           collections=biases_collections, | ||||
|                                           trainable=trainable) | ||||
|         outputs = nn.bias_add(outputs, biases) | ||||
|     if len(static_shape) > 2: | ||||
|       # Reshape back outputs | ||||
| @ -773,127 +793,13 @@ def legacy_fully_connected(x, | ||||
|     return _apply_activation(y, activation_fn, output_collections) | ||||
| 
 | ||||
| 
 | ||||
| def legacy_convolution2d(x, | ||||
|                          num_output_channels, | ||||
|                          kernel_size, | ||||
|                          activation_fn=None, | ||||
|                          stride=(1, 1), | ||||
|                          padding='SAME', | ||||
|                          weight_init=initializers.xavier_initializer_conv2d(), | ||||
|                          bias_init=standard_ops.zeros_initializer, | ||||
|                          name=None, | ||||
|                          weight_collections=(ops.GraphKeys.WEIGHTS,), | ||||
|                          bias_collections=(ops.GraphKeys.BIASES,), | ||||
|                          output_collections=(ops.GraphKeys.ACTIVATIONS,), | ||||
|                          trainable=True, | ||||
|                          weight_regularizer=None, | ||||
|                          bias_regularizer=None): | ||||
|   # pylint: disable=g-docstring-has-escape | ||||
|   """Adds the parameters for a conv2d layer and returns the output. | ||||
| 
 | ||||
|   A neural network convolution layer is generally defined as: | ||||
|   \\\\(y = f(conv2d(w, x) + b)\\\\) where **f** is given by `activation_fn`, | ||||
|   **conv2d** is `tf.nn.conv2d` and `x` has shape | ||||
|   `[batch, height, width, channels]`. The output of this op is of shape | ||||
|   `[batch, out_height, out_width, num_output_channels]`, where `out_width` and | ||||
|   `out_height` are determined by the `padding` argument. See `conv2D` for | ||||
|   details. | ||||
| 
 | ||||
|   This op creates `w` and optionally `b` and adds various summaries that can be | ||||
|   useful for visualizing learning or diagnosing training problems. Bias can be | ||||
|   disabled by setting `bias_init` to `None`. | ||||
| 
 | ||||
|   The variable creation is compatible with `tf.variable_scope` and so can be | ||||
|   reused with `tf.variable_scope` or `tf.make_template`. | ||||
| 
 | ||||
|   Most of the details of variable creation can be controlled by specifying the | ||||
|   initializers (`weight_init` and `bias_init`) and which collections to place | ||||
|   the created variables in (`weight_collections` and `bias_collections`). | ||||
| 
 | ||||
|   A per layer regularization can be specified by setting `weight_regularizer`. | ||||
|   This is only applied to weights and not the bias. | ||||
| 
 | ||||
|   Args: | ||||
|     x: A 4-D input `Tensor`. | ||||
|     num_output_channels: The number of output channels (i.e. the size of the | ||||
|       last dimension of the output). | ||||
|     kernel_size: A length 2 `list` or `tuple` containing the kernel size. | ||||
|     activation_fn: A function that requires a single Tensor that is applied as a | ||||
|       non-linearity. | ||||
|     stride: A length 2 `list` or `tuple` specifying the stride of the sliding | ||||
|       window across the image. | ||||
|     padding: A `string` from: "SAME", "VALID". The type of padding algorithm to | ||||
|       use. | ||||
|     weight_init: An optional initialization. If not specified, uses Xavier | ||||
|       initialization (see `tf.learn.xavier_initializer`). | ||||
|     bias_init: An initializer for the bias, defaults to 0. Set to`None` in order | ||||
|       to disable bias. | ||||
|     name: The name for this operation is used to name operations and to find | ||||
|       variables. If specified it must be unique for this scope, otherwise a | ||||
|       unique name starting with "convolution2d" will be created.  See | ||||
|       `tf.variable_op_scope` for details. | ||||
|     weight_collections: List of graph collections to which weights are added. | ||||
|     bias_collections: List of graph collections to which biases are added. | ||||
|     output_collections: List of graph collections to which outputs are added. | ||||
|     trainable: If `True` also add variables to the graph collection | ||||
|       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). | ||||
|     weight_regularizer: A regularizer like the result of | ||||
|       `l1_regularizer` or `l2_regularizer`. Used for weights. | ||||
|     bias_regularizer: A regularizer like the result of | ||||
|       `l1_regularizer` or `l2_regularizer`. Used for biases. | ||||
| 
 | ||||
|   Returns: | ||||
|     The result of applying a 2-D convolutional layer. | ||||
| 
 | ||||
|   Raises: | ||||
|     ValueError: If `kernel_size` or `stride` are not length 2. | ||||
|   """ | ||||
|   with variable_scope.variable_op_scope([x], name, 'convolution2d'): | ||||
|     num_input_channels = x.get_shape().dims[3].value | ||||
| 
 | ||||
|     if len(kernel_size) != 2: | ||||
|       raise ValueError('kernel_size must be length 2: %d ' % kernel_size) | ||||
|     if len(stride) != 2: | ||||
|       raise ValueError('stride must be length 2: %d' % stride) | ||||
| 
 | ||||
|     stride = [1, stride[0], stride[1], 1] | ||||
|     shape = [kernel_size[0], kernel_size[1], num_input_channels, | ||||
|              num_output_channels] | ||||
|     dtype = x.dtype.base_dtype | ||||
| 
 | ||||
|     weight_collections = set(list(weight_collections or []) + | ||||
|                              [ops.GraphKeys.VARIABLES]) | ||||
|     w = variable_scope.get_variable('weights', | ||||
|                                     shape=shape, | ||||
|                                     dtype=dtype, | ||||
|                                     initializer=weight_init, | ||||
|                                     collections=weight_collections, | ||||
|                                     regularizer=weight_regularizer, | ||||
|                                     trainable=trainable) | ||||
| 
 | ||||
|     y = nn.conv2d(x, w, stride, padding) | ||||
| 
 | ||||
|     if bias_init is not None: | ||||
|       bias_collections = set(list(bias_collections or []) + | ||||
|                              [ops.GraphKeys.VARIABLES]) | ||||
|       b = variable_scope.get_variable('bias', | ||||
|                                       shape=[num_output_channels], | ||||
|                                       dtype=dtype, | ||||
|                                       initializer=bias_init, | ||||
|                                       collections=bias_collections, | ||||
|                                       regularizer=bias_regularizer, | ||||
|                                       trainable=trainable) | ||||
| 
 | ||||
|       y = nn.bias_add(y, b) | ||||
| 
 | ||||
|     return _apply_activation(y, activation_fn, output_collections) | ||||
| 
 | ||||
| 
 | ||||
| # TODO(eiderm): Verify and fix autocomplete in colab (also relu6). | ||||
| # Simple aliases which remove the activation_fn parameter. | ||||
| legacy_relu = functools.partial(legacy_fully_connected, activation_fn=nn.relu) | ||||
| legacy_relu6 = functools.partial(legacy_fully_connected, activation_fn=nn.relu6) | ||||
| # Simple alias for fully_connected which removes the activation_fn parameter. | ||||
| legacy_linear = functools.partial(legacy_fully_connected, activation_fn=None) | ||||
| relu = functools.partial(fully_connected, activation_fn=nn.relu) | ||||
| relu6 = functools.partial(fully_connected, activation_fn=nn.relu6) | ||||
| linear = functools.partial(fully_connected, activation_fn=None) | ||||
| 
 | ||||
| # Simple alias for convolution2d. | ||||
| conv2d = convolution2d | ||||
|  | ||||
| @ -435,15 +435,16 @@ class FCTest(tf.test.TestCase): | ||||
| 
 | ||||
|   def testCreateFC(self): | ||||
|     height, width = 3, 3 | ||||
|     with self.test_session(): | ||||
|       inputs = tf.random_uniform((5, height * width * 3), seed=1) | ||||
|       output = tf.contrib.layers.fully_connected(inputs, 32) | ||||
|       self.assertEquals(output.op.name, 'fully_connected/Relu') | ||||
|       self.assertListEqual(output.get_shape().as_list(), [5, 32]) | ||||
|       weights = tf.contrib.framework.get_variables_by_name('weights')[0] | ||||
|       self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32]) | ||||
|       biases = tf.contrib.framework.get_variables_by_name('biases')[0] | ||||
|       self.assertListEqual(biases.get_shape().as_list(), [32]) | ||||
|     for layer_fn in (tf.contrib.layers.fully_connected, tf.contrib.layers.relu): | ||||
|       with tf.Graph().as_default() as g, self.test_session(g): | ||||
|         inputs = tf.random_uniform((5, height * width * 3), seed=1) | ||||
|         output = layer_fn(inputs, 32) | ||||
|         self.assertEquals(output.op.name, 'fully_connected/Relu') | ||||
|         self.assertListEqual(output.get_shape().as_list(), [5, 32]) | ||||
|         weights = tf.contrib.framework.get_variables_by_name('weights')[0] | ||||
|         self.assertListEqual(weights.get_shape().as_list(), [3 * 3 * 3, 32]) | ||||
|         biases = tf.contrib.framework.get_variables_by_name('biases')[0] | ||||
|         self.assertListEqual(biases.get_shape().as_list(), [32]) | ||||
| 
 | ||||
|   def testCreateFCWithScope(self): | ||||
|     height, width = 3, 3 | ||||
| @ -985,27 +986,6 @@ class LegacyFullyConnectedTest(tf.test.TestCase): | ||||
|     self.assertEqual(0, | ||||
|                      len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))) | ||||
| 
 | ||||
|   def test_relu6_layer_basic_use(self): | ||||
|     output = tf.contrib.layers.legacy_relu6(self.input, 8) | ||||
| 
 | ||||
|     with tf.Session() as sess: | ||||
|       with self.assertRaises(tf.errors.FailedPreconditionError): | ||||
|         sess.run(output) | ||||
| 
 | ||||
|       tf.initialize_all_variables().run() | ||||
|       out_value = sess.run(output) | ||||
| 
 | ||||
|     self.assertEqual(output.get_shape().as_list(), [2, 8]) | ||||
|     self.assertTrue(np.all(out_value >= 0), | ||||
|                     'Relu6 should have all values >= 0.') | ||||
|     self.assertTrue(np.all(out_value <= 6), | ||||
|                     'Relu6 should have all values <= 6.') | ||||
| 
 | ||||
|     self.assertEqual(2, | ||||
|                      len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))) | ||||
|     self.assertEqual(0, | ||||
|                      len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))) | ||||
| 
 | ||||
|   def test_variable_reuse_with_scope(self): | ||||
|     with tf.variable_scope('test') as vs: | ||||
|       output1 = tf.contrib.layers.legacy_relu(self.input, 8) | ||||
| @ -1208,131 +1188,5 @@ class LegacyFullyConnectedTest(tf.test.TestCase): | ||||
|                                                  activation_fn=tf.nn.softmax) | ||||
| 
 | ||||
| 
 | ||||
| class LegacyConvolution2dTest(tf.test.TestCase): | ||||
| 
 | ||||
|   def setUp(self): | ||||
|     tf.test.TestCase.setUp(self) | ||||
|     tf.set_random_seed(1234) | ||||
|     self.input = tf.constant(np.arange(2 * 3 * 3 * 4).reshape( | ||||
|         [2, 3, 3, 4]).astype(np.float32)) | ||||
|     assert not tf.get_collection(tf.GraphKeys.SUMMARIES) | ||||
| 
 | ||||
|   def test_basic_use(self): | ||||
|     output = tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                                     8, (3, 3), | ||||
|                                                     activation_fn=tf.nn.relu) | ||||
| 
 | ||||
|     with tf.Session() as sess: | ||||
|       with self.assertRaises(tf.errors.FailedPreconditionError): | ||||
|         sess.run(output) | ||||
| 
 | ||||
|       tf.initialize_all_variables().run() | ||||
|       out_value = sess.run(output) | ||||
| 
 | ||||
|     self.assertEqual(output.get_shape().as_list(), [2, 3, 3, 8]) | ||||
|     self.assertTrue(np.all(out_value >= 0), | ||||
|                     'Relu should have capped all values.') | ||||
|     self.assertEqual(2, | ||||
|                      len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))) | ||||
|     self.assertEqual(0, | ||||
|                      len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))) | ||||
| 
 | ||||
|   def test_variable_reuse_with_scope(self): | ||||
|     with tf.variable_scope('test') as vs: | ||||
|       output1 = tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                                        8, (3, 3), | ||||
|                                                        activation_fn=tf.nn.relu) | ||||
|       output2 = tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                                        8, (3, 3), | ||||
|                                                        activation_fn=tf.nn.relu) | ||||
| 
 | ||||
|     with tf.variable_scope(vs, reuse=True): | ||||
|       output3 = tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                                        8, (3, 3), | ||||
|                                                        activation_fn=tf.nn.relu) | ||||
| 
 | ||||
|     with tf.Session() as sess: | ||||
|       tf.initialize_all_variables().run() | ||||
|       out_value1, out_value2, out_value3 = sess.run([output1, output2, output3]) | ||||
| 
 | ||||
|     self.assertFalse(np.allclose(out_value1, out_value2)) | ||||
|     self.assertAllClose(out_value1, out_value3) | ||||
| 
 | ||||
|   def test_variable_reuse_with_template(self): | ||||
|     tmpl1 = tf.make_template('test', | ||||
|                              tf.contrib.layers.legacy_convolution2d, | ||||
|                              kernel_size=(3, 3), | ||||
|                              num_output_channels=8) | ||||
|     output1 = tmpl1(self.input) | ||||
|     output2 = tmpl1(self.input) | ||||
| 
 | ||||
|     with tf.Session() as sess: | ||||
|       tf.initialize_all_variables().run() | ||||
|       out_value1, out_value2 = sess.run([output1, output2]) | ||||
|     self.assertAllClose(out_value1, out_value2) | ||||
| 
 | ||||
|   def test_custom_initializers(self): | ||||
|     output = tf.contrib.layers.legacy_convolution2d( | ||||
|         self.input, | ||||
|         2, (3, 3), | ||||
|         activation_fn=tf.nn.relu, | ||||
|         weight_init=tf.constant_initializer(2.0), | ||||
|         bias_init=tf.constant_initializer(1.0), | ||||
|         padding='VALID') | ||||
| 
 | ||||
|     with tf.Session() as sess: | ||||
|       tf.initialize_all_variables().run() | ||||
|       out_value = sess.run(output) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.array([[[[1261., 1261.]]], [[[3853., 3853.]]]]), out_value) | ||||
| 
 | ||||
|   def test_custom_collections(self): | ||||
|     tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                            2, (3, 3), | ||||
|                                            activation_fn=tf.nn.relu, | ||||
|                                            weight_collections=['unbiased'], | ||||
|                                            bias_collections=['biased']) | ||||
| 
 | ||||
|     self.assertEquals(1, len(tf.get_collection('unbiased'))) | ||||
|     self.assertEquals(1, len(tf.get_collection('biased'))) | ||||
| 
 | ||||
|   def test_all_custom_collections(self): | ||||
|     tf.contrib.layers.legacy_convolution2d( | ||||
|         self.input, | ||||
|         2, (3, 3), | ||||
|         activation_fn=tf.nn.relu, | ||||
|         weight_collections=['unbiased', 'all'], | ||||
|         bias_collections=['biased', 'all']) | ||||
| 
 | ||||
|     self.assertEquals(1, len(tf.get_collection('unbiased'))) | ||||
|     self.assertEquals(1, len(tf.get_collection('biased'))) | ||||
|     self.assertEquals(2, len(tf.get_collection('all'))) | ||||
|     self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), | ||||
|                       tf.get_collection('all')) | ||||
| 
 | ||||
|   def test_regularizer(self): | ||||
|     cnt = [0] | ||||
|     tensor = tf.constant(5.0) | ||||
|     def test_fn(_): | ||||
|       cnt[0] += 1 | ||||
|       return tensor | ||||
| 
 | ||||
|     tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                            2, (3, 3), | ||||
|                                            weight_regularizer=test_fn) | ||||
| 
 | ||||
|     self.assertEqual([tensor], | ||||
|                      tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) | ||||
|     self.assertEqual(1, cnt[0]) | ||||
| 
 | ||||
|   def test_no_bias(self): | ||||
|     tf.contrib.layers.legacy_convolution2d(self.input, | ||||
|                                            2, (3, 3), | ||||
|                                            bias_init=None) | ||||
|     self.assertEqual(1, | ||||
|                      len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   tf.test.main() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user