diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index a476b0f72a3..3b96d4362fd 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -268,7 +268,7 @@ def conv1d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for 1D convolution layer (e.g. temporal convolution).
 
   This layer creates a convolution kernel that is convolved
@@ -435,7 +435,7 @@ def conv2d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for the 2D convolution layer.
 
   This layer creates a convolution kernel that is convolved
@@ -608,7 +608,7 @@ def conv3d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for the 3D convolution layer.
 
   This layer creates a convolution kernel that is convolved
@@ -867,7 +867,7 @@ def separable_conv2d(inputs,
                      activity_regularizer=None,
                      trainable=True,
                      name=None,
-                     reuse=False):
+                     reuse=None):
   """Functional interface for the depthwise separable 2D convolution layer.
 
   This layer performs a depthwise convolution that acts separately on
@@ -1128,7 +1128,7 @@ def conv2d_transpose(inputs,
                      activity_regularizer=None,
                      trainable=True,
                      name=None,
-                     reuse=False):
+                     reuse=None):
   """Transposed convolution layer (sometimes called Deconvolution).
 
   The need for transposed convolutions generally arises
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index c47e92c5824..1a5fe5c9b7d 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -18,11 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import convolutional as conv_layers
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -191,21 +196,45 @@ class ConvTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d(images, 32, [3, 3], name='conv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d(images, 32, [3, 3], name='conv1', reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('kernel' in weights[0].name)
+        self.assertTrue('bias' in weights[1].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 3, 32)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[1], np.zeros((32)))
 
   def testFunctionalConv2DNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 4)
+    self.assertEqual(len(variables.trainable_variables()), 4)
 
 
 class SeparableConv2DTest(test.TestCase):
@@ -323,22 +352,48 @@ class SeparableConv2DTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
     conv_layers.separable_conv2d(
         images, 32, [3, 3], name='sepconv1', reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
+
+  def testFunctionalConv2DReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+      self.assertEqual(len(variables.trainable_variables()), 3)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+      self.assertEqual(len(variables.trainable_variables()), 3)
+
+  def testFunctionalConv2DInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('depthwise_kernel' in weights[0].name)
+        self.assertTrue('pointwise_kernel' in weights[1].name)
+        self.assertTrue('bias' in weights[2].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 3, 1)))
+        self.assertAllClose(weights[1], np.ones((1, 1, 3, 32)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[2], np.zeros((32)))
 
   def testFunctionalConv2DNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.separable_conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
     conv_layers.separable_conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 6)
+    self.assertEqual(len(variables.trainable_variables()), 6)
 
   def testSeparableConv2DDepthwiseRegularizer(self):
     height, width = 7, 9
@@ -511,21 +566,45 @@ class Conv2DTransposeTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1',
                                  reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DTransposeReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DTransposeInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('kernel' in weights[0].name)
+        self.assertTrue('bias' in weights[1].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 32, 3)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[1], np.zeros((32)))
 
   def testFunctionalConv2DTransposeNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d_transpose(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d_transpose(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 4)
+    self.assertEqual(len(variables.trainable_variables()), 4)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index b8b6cd97da7..92894e14472 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -167,7 +167,7 @@ def dense(
     activity_regularizer=None,
     trainable=True,
     name=None,
-    reuse=False):
+    reuse=None):
   """Functional interface for the densely-connected layer.
 
   This layer implements the operation:
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 50e3ebc72f5..cfcee7b788f 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -207,6 +207,16 @@ class DenseTest(test.TestCase):
     vars2 = variables.trainable_variables()
     self.assertEqual(vars1, vars2)
 
+  def testFunctionalDenseTwiceReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      inputs = random_ops.random_uniform((5, 3), seed=1)
+      core_layers.dense(inputs, 2, name='my_dense')
+      vars1 = variables.trainable_variables()
+    with variable_scope.variable_scope('scope', reuse=True):
+      core_layers.dense(inputs, 2, name='my_dense')
+      vars2 = variables.trainable_variables()
+    self.assertEqual(vars1, vars2)
+
   def testFunctionalDenseInitializerFromScope(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index fcbc69f2c52..4a59d779483 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -257,7 +257,7 @@ def batch_normalization(inputs,
                         training=False,
                         trainable=True,
                         name=None,
-                        reuse=False):
+                        reuse=None):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 93efc09ca06..91b7cb6f483 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.layers import normalization as normalization_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -454,6 +455,20 @@ class BNTest(test.TestCase):
     self.assertAlmostEqual(np.mean(normed_np_output), 0., places=2)
     self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
+  def testFunctionalReuseFromScope(self):
+    inputs = variables.Variable(
+        np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
+    epsilon = 1e-3
+    training = array_ops.placeholder(dtype='bool')
+    with variable_scope.variable_scope('scope'):
+      _ = normalization_layers.batch_norm(
+          inputs, axis=-1, momentum=0.9, epsilon=epsilon, training=training)
+      self.assertEqual(len(variables.global_variables()), 5)
+    with variable_scope.variable_scope('scope', reuse=True):
+      _ = normalization_layers.batch_norm(
+          inputs, axis=-1, momentum=0.9, epsilon=epsilon, training=training)
+      self.assertEqual(len(variables.global_variables()), 5)
+
   def testNoCenter(self):
     bn = normalization_layers.BatchNormalization(axis=1, center=False)
     inputs = random_ops.random_uniform((5, 4, 3), seed=1)
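Note on the behavioral change: switching the functional layers' default from
reuse=False to reuse=None means the layers now defer to the enclosing variable
scope instead of unconditionally overriding its reuse flag (and likewise pick
up other scope properties such as a default initializer, which the new
*InitializerFromScope tests exercise). The sketch below is illustrative and
not part of the patch; it mirrors the new testFunctionalConv2DReuseFromScope
test using the same internal modules the tests import.

    from tensorflow.python.layers import convolutional as conv_layers
    from tensorflow.python.ops import random_ops
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.ops import variables

    images = random_ops.random_uniform((5, 7, 9, 3), seed=1)
    with variable_scope.variable_scope('scope'):
      # Creates 'scope/conv1/kernel' and 'scope/conv1/bias'.
      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
    with variable_scope.variable_scope('scope', reuse=True):
      # With reuse=None (the new default) the layer inherits reuse=True from
      # the surrounding scope, so this call shares the variables created
      # above instead of attempting to create duplicates.
      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
    # Still only one kernel and one bias in total.
    assert len(variables.trainable_variables()) == 2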