Default reuse should be None, never False. Correct it in layers and add tests.
Change: 144903238
This commit is contained in:
parent f0a1af4a8b
commit 1ed4e69f7b
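Why None rather than False: in the TF 1.x variable-scope model, reuse=None means "inherit the reuse flag of the enclosing variable scope," whereas a default of reuse=False would try to override an enclosing variable_scope(..., reuse=True) instead of inheriting from it. The new *ReuseFromScope tests below exercise exactly this. A minimal sketch of the resulting behavior, assuming the public TF 1.x tf.layers/tf.variable_scope aliases for the modules touched in this diff (illustration only, not part of the commit):

import tensorflow as tf

images = tf.random_uniform((5, 7, 9, 3), seed=1)

with tf.variable_scope('scope'):
  # First call creates 'scope/conv1/kernel' and 'scope/conv1/bias'.
  tf.layers.conv2d(images, 32, [3, 3], name='conv1')

with tf.variable_scope('scope', reuse=True):
  # Shares the variables above only because the layer's default
  # reuse=None defers to the enclosing scope; a hard-coded default of
  # reuse=False would fight the scope instead of inheriting from it.
  tf.layers.conv2d(images, 32, [3, 3], name='conv1')

# Still only one kernel/bias pair: the second call reused the variables.
assert len(tf.trainable_variables()) == 2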
@@ -268,7 +268,7 @@ def conv1d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for 1D convolution layer (e.g. temporal convolution).
 
   This layer creates a convolution kernel that is convolved
@@ -435,7 +435,7 @@ def conv2d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for the 2D convolution layer.
 
   This layer creates a convolution kernel that is convolved
@@ -608,7 +608,7 @@ def conv3d(inputs,
            activity_regularizer=None,
            trainable=True,
            name=None,
-           reuse=False):
+           reuse=None):
   """Functional interface for the 3D convolution layer.
 
   This layer creates a convolution kernel that is convolved
@@ -867,7 +867,7 @@ def separable_conv2d(inputs,
                      activity_regularizer=None,
                      trainable=True,
                      name=None,
-                     reuse=False):
+                     reuse=None):
   """Functional interface for the depthwise separable 2D convolution layer.
 
   This layer performs a depthwise convolution that acts separately on
@@ -1128,7 +1128,7 @@ def conv2d_transpose(inputs,
                      activity_regularizer=None,
                      trainable=True,
                      name=None,
-                     reuse=False):
+                     reuse=None):
   """Transposed convolution layer (sometimes called Deconvolution).
 
   The need for transposed convolutions generally arises
@@ -18,11 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import convolutional as conv_layers
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -191,21 +196,45 @@ class ConvTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d(images, 32, [3, 3], name='conv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d(images, 32, [3, 3], name='conv1', reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.conv2d(images, 32, [3, 3], name='conv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('kernel' in weights[0].name)
+        self.assertTrue('bias' in weights[1].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 3, 32)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[1], np.zeros((32)))
 
   def testFunctionalConv2DNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 4)
+    self.assertEqual(len(variables.trainable_variables()), 4)
 
 
 class SeparableConv2DTest(test.TestCase):
@@ -323,22 +352,48 @@ class SeparableConv2DTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
     conv_layers.separable_conv2d(
         images, 32, [3, 3], name='sepconv1', reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
+
+  def testFunctionalConv2DReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+      self.assertEqual(len(variables.trainable_variables()), 3)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+      self.assertEqual(len(variables.trainable_variables()), 3)
+
+  def testFunctionalConv2DInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.separable_conv2d(images, 32, [3, 3], name='sepconv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('depthwise_kernel' in weights[0].name)
+        self.assertTrue('pointwise_kernel' in weights[1].name)
+        self.assertTrue('bias' in weights[2].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 3, 1)))
+        self.assertAllClose(weights[1], np.ones((1, 1, 3, 32)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[2], np.zeros((32)))
 
   def testFunctionalConv2DNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.separable_conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 3)
+    self.assertEqual(len(variables.trainable_variables()), 3)
     conv_layers.separable_conv2d(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 6)
+    self.assertEqual(len(variables.trainable_variables()), 6)
 
   def testSeparableConv2DDepthwiseRegularizer(self):
     height, width = 7, 9
@@ -511,21 +566,45 @@ class Conv2DTransposeTest(test.TestCase):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1', reuse=True)
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DTransposeReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      height, width = 7, 9
+      images = random_ops.random_uniform((5, height, width, 3), seed=1)
+      conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+    with variable_scope.variable_scope('scope', reuse=True):
+      conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+      self.assertEqual(len(variables.trainable_variables()), 2)
+
+  def testFunctionalConv2DTransposeInitializerFromScope(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          'scope', initializer=init_ops.ones_initializer()):
+        height, width = 7, 9
+        images = random_ops.random_uniform((5, height, width, 3), seed=1)
+        conv_layers.conv2d_transpose(images, 32, [3, 3], name='deconv1')
+        weights = variables.trainable_variables()
+        # Check the names of weights in order.
+        self.assertTrue('kernel' in weights[0].name)
+        self.assertTrue('bias' in weights[1].name)
+        sess.run(variables.global_variables_initializer())
+        weights = sess.run(weights)
+        # Check that the kernel weights got initialized to ones (from scope)
+        self.assertAllClose(weights[0], np.ones((3, 3, 32, 3)))
+        # Check that the bias still got initialized to zeros.
+        self.assertAllClose(weights[1], np.zeros((32)))
 
   def testFunctionalConv2DTransposeNoReuse(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     conv_layers.conv2d_transpose(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
+    self.assertEqual(len(variables.trainable_variables()), 2)
     conv_layers.conv2d_transpose(images, 32, [3, 3])
-    self.assertEqual(
-        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 4)
+    self.assertEqual(len(variables.trainable_variables()), 4)
 
 
 if __name__ == '__main__':
@@ -167,7 +167,7 @@ def dense(
     activity_regularizer=None,
     trainable=True,
     name=None,
-    reuse=False):
+    reuse=None):
   """Functional interface for the densely-connected layer.
 
   This layer implements the operation:
@@ -207,6 +207,16 @@ class DenseTest(test.TestCase):
     vars2 = variables.trainable_variables()
     self.assertEqual(vars1, vars2)
 
+  def testFunctionalDenseTwiceReuseFromScope(self):
+    with variable_scope.variable_scope('scope'):
+      inputs = random_ops.random_uniform((5, 3), seed=1)
+      core_layers.dense(inputs, 2, name='my_dense')
+      vars1 = variables.trainable_variables()
+    with variable_scope.variable_scope('scope', reuse=True):
+      core_layers.dense(inputs, 2, name='my_dense')
+      vars2 = variables.trainable_variables()
+    self.assertEqual(vars1, vars2)
+
   def testFunctionalDenseInitializerFromScope(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
@@ -257,7 +257,7 @@ def batch_normalization(inputs,
                         training=False,
                         trainable=True,
                         name=None,
-                        reuse=False):
+                        reuse=None):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@@ -26,6 +26,7 @@ from tensorflow.python.layers import normalization as normalization_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -454,6 +455,20 @@ class BNTest(test.TestCase):
     self.assertAlmostEqual(np.mean(normed_np_output), 0., places=2)
     self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
+  def testFunctionalReuseFromScope(self):
+    inputs = variables.Variable(
+        np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
+    epsilon = 1e-3
+    training = array_ops.placeholder(dtype='bool')
+    with variable_scope.variable_scope('scope'):
+      _ = normalization_layers.batch_norm(
+          inputs, axis=-1, momentum=0.9, epsilon=epsilon, training=training)
+      self.assertEqual(len(variables.global_variables()), 5)
+    with variable_scope.variable_scope('scope', reuse=True):
+      _ = normalization_layers.batch_norm(
+          inputs, axis=-1, momentum=0.9, epsilon=epsilon, training=training)
+      self.assertEqual(len(variables.global_variables()), 5)
+
   def testNoCenter(self):
     bn = normalization_layers.BatchNormalization(axis=1, center=False)
     inputs = random_ops.random_uniform((5, 4, 3), seed=1)
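A note on the *InitializerFromScope tests above: a scope-level initializer applies only to variables created without an explicit initializer, so the kernels pick up the scope's ones_initializer while the biases keep the layers' explicit zeros initializer. A rough public-API equivalent, again assuming the TF 1.x tf.layers aliases (illustration only, not part of the commit):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
  with tf.variable_scope('scope', initializer=tf.ones_initializer()):
    images = tf.random_uniform((5, 7, 9, 3), seed=1)
    tf.layers.conv2d(images, 32, [3, 3], name='conv1')
  kernel, bias = tf.trainable_variables()
  sess.run(tf.global_variables_initializer())
  # Kernel inherits the scope initializer; bias keeps its explicit zeros.
  np.testing.assert_allclose(sess.run(kernel), np.ones((3, 3, 3, 32)))
  np.testing.assert_allclose(sess.run(bias), np.zeros(32))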
|
Loading…
x
Reference in New Issue
Block a user