Make changes to prepare for the fused option of batch norm to be set to None (None means use fused batch norm if possible).

PiperOrigin-RevId: 159649743
Yao Zhang 2017-06-20 20:01:39 -07:00 committed by TensorFlower Gardener
parent a4a4698323
commit b3f33ad466
7 changed files with 135 additions and 99 deletions
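For context, a minimal usage sketch of the semantics this change prepares for; the shapes and hyperparameters below are illustrative and not taken from the diff. fused=True forces the fused kernel, fused=False forces the non-fused path, and fused=None asks the library to pick the fused implementation when the configuration supports it.

    # Sketch only; values are made up for illustration.
    import tensorflow as tf
    from tensorflow.contrib import layers as contrib_layers

    inputs = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))  # NHWC input
    net = contrib_layers.batch_norm(
        inputs,
        decay=0.9,
        is_training=True,
        fused=None)  # None: use the fused kernel if the configuration allows it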

View File

@@ -7,6 +7,7 @@ exports_files(["LICENSE"])
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 py_library(
@@ -393,12 +394,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "normalization_test",
     size = "small",
     srcs = ["python/keras/layers/normalization_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":keras",
         ":testing_utils",
         "//tensorflow/python:client_testlib",

View File

@@ -94,7 +94,8 @@ class NoiseLayersTest(test.TestCase):
     np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
 
   def test_batchnorm_convnet(self):
-    with self.test_session():
-      model = keras.models.Sequential()
-      norm = keras.layers.BatchNormalization(
-          axis=1, input_shape=(3, 4, 4), momentum=0.8)
+    if test.is_gpu_available(cuda_only=True):
+      with self.test_session(use_gpu=True):
+        model = keras.models.Sequential()
+        norm = keras.layers.BatchNormalization(
+            axis=1, input_shape=(3, 4, 4), momentum=0.8)

View File

@@ -257,6 +257,7 @@ def _fused_batch_norm(
                                                       'beta')
     if not param_initializers:
       param_initializers = {}
+    if center:
       beta_initializer = param_initializers.get('beta',
                                                 init_ops.zeros_initializer())
       beta = variables.model_variable(
@@ -266,9 +267,12 @@ def _fused_batch_norm(
           initializer=beta_initializer,
           collections=beta_collections,
           trainable=trainable_beta)
-    trainable_gamma = trainable and scale
-    gamma_collections = utils.get_variable_collections(variables_collections,
-                                                       'gamma')
+    else:
+      beta = array_ops.constant(0.0, shape=params_shape)
+
+    if scale:
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
       gamma_initializer = param_initializers.get('gamma',
                                                  init_ops.ones_initializer())
       gamma = variables.model_variable(
@@ -277,7 +281,9 @@ def _fused_batch_norm(
           dtype=dtype,
           initializer=gamma_initializer,
           collections=gamma_collections,
-          trainable=trainable_gamma)
+          trainable=trainable)
+    else:
+      gamma = array_ops.constant(1.0, shape=params_shape)
 
     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections.
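The center/scale branches above amount to fixing beta at 0 when center=False and gamma at 1 when scale=False. A small NumPy sketch of the resulting arithmetic; the data, shapes, and epsilon are illustrative:

    import numpy as np

    x = np.random.rand(8, 4).astype(np.float32)   # a rank-2 batch of inputs
    mean, var = x.mean(axis=0), x.var(axis=0)
    epsilon = 1e-3

    gamma = np.ones(4, np.float32)   # scale=False: gamma is a constant 1
    beta = np.zeros(4, np.float32)   # center=False: beta is a constant 0
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta  # plain standardization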
@@ -449,7 +455,8 @@ def batch_norm(inputs,
       then the batch normalization uses weighted mean and
       variance. (This can be used to correct for bias in training
       example selection.)
-    fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
     data_format: A string. `NHWC` (default) and `NCHW` are supported.
     zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
       pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -473,7 +480,6 @@ def batch_norm(inputs,
   Raises:
     ValueError: If `batch_weights` is not None and `fused` is True.
-    ValueError: If `param_regularizers` is not None and `fused` is True.
     ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
     ValueError: If the rank of `inputs` is undefined.
     ValueError: If rank or channels dimension of `inputs` is undefined.
@@ -487,6 +493,21 @@ def batch_norm(inputs,
                        'supported for fused batch norm.')
     if renorm:
       raise ValueError('Renorm is not supported for fused batch norm.')
+
+  # Only use _fused_batch_norm (1) if fused is set True or if it is
+  # possible to use (currently it doesn't support batch weights,
+  # renorm, and the case when rank is neither 2 nor 4),
+  # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
+  # or non-default updates_collections (not implemented in
+  # normalization_layers.BatchNormalization yet); otherwise use the fused
+  # implementation in normalization_layers.BatchNormalization.
+  inputs = ops.convert_to_tensor(inputs)
+  rank = inputs.get_shape().ndims
+  feature_supported = batch_weights is None and not renorm and rank in [2, 4]
+  possible_to_fuse = fused is None and feature_supported
+  if (fused or possible_to_fuse) and (
+      zero_debias_moving_mean or rank == 2 or
+      updates_collections is not ops.GraphKeys.UPDATE_OPS):
     return _fused_batch_norm(
         inputs,
         decay=decay,
@@ -552,7 +573,8 @@ def batch_norm(inputs,
         renorm_momentum=renorm_decay,
         name=sc.name,
         _scope=sc,
-        _reuse=reuse)
+        _reuse=reuse,
+        fused=fused)
     outputs = layer.apply(inputs, training=is_training)
 
     # Add variables to collections.
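To make the new dispatch easier to follow, here is a standalone restatement of the condition added above as a plain function. The helper name and the string stand-in for ops.GraphKeys.UPDATE_OPS are hypothetical, used only for illustration.

    def uses_fused_path(fused, rank, batch_weights=None, renorm=False,
                        zero_debias_moving_mean=False,
                        updates_collections='update_ops'):
      # Mirrors the dispatch logic added to contrib's batch_norm (sketch).
      feature_supported = batch_weights is None and not renorm and rank in [2, 4]
      possible_to_fuse = fused is None and feature_supported
      return bool((fused or possible_to_fuse) and
                  (zero_debias_moving_mean or rank == 2 or
                   updates_collections != 'update_ops'))

    print(uses_fused_path(fused=None, rank=2))  # True: rank-2 input
    print(uses_fused_path(fused=None, rank=4))  # False: left to the core layer
    print(uses_fused_path(fused=True, rank=4,
                          zero_debias_moving_mean=True))  # True

Calls that fall through this predicate are handled by normalization_layers.BatchNormalization, which now receives the fused argument unchanged.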

View File

@@ -1703,13 +1703,6 @@ class BatchNormTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
       _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
 
-  def testParamRegularizersFused(self):
-    with ops.Graph().as_default() as g, self.test_session(g):
-      inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
-      with self.assertRaisesRegexp(ValueError,
-                                   'Regularizers are not currently'):
-        _layers.batch_norm(inputs, param_regularizers={}, fused=True)
-
   def _testCreateOp(self, fused):
     height, width = 3, 3
     with self.test_session():
@@ -1780,7 +1773,8 @@ class BatchNormTest(test.TestCase):
     height, width = 3, 3
     with self.test_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
-      _layers.batch_norm(images, scale=True, zero_debias_moving_mean=True)
+      _layers.batch_norm(
+          images, scale=True, zero_debias_moving_mean=True, fused=False)
       self.assertEqual(len(variables.get_model_variables()), 6)
       moving_mean = variables.get_variables_by_name('moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
@@ -1874,7 +1868,8 @@ class BatchNormTest(test.TestCase):
           images,
           decay=0.1,
           updates_collections=None,
-          zero_debias_moving_mean=True)
+          zero_debias_moving_mean=True,
+          fused=False)
       moving_mean = variables.get_variables_by_name('BatchNorm/moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
       biased = variables.get_variables_by_name('biased')[0]
@@ -2523,7 +2518,7 @@ class BatchNormTest(test.TestCase):
   def _runBatchNormalizationWithFormat(self, shape, data_format, is_training):
     channels = shape[-1]
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       images = np.arange(np.product(shape), dtype=np.float32).reshape(shape)
       beta = init_ops.constant_initializer(
           np.arange(
@@ -2561,6 +2556,7 @@ class BatchNormTest(test.TestCase):
       return sess.run(output)
 
   def testNHWCAndNCHWInferenceProduceSameOutput(self):
+    if test.is_gpu_available(cuda_only=True):
       for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
         nhwc = self._runBatchNormalizationWithFormat(
             data_format='NHWC', shape=shape, is_training=False)
@@ -2569,6 +2565,7 @@ class BatchNormTest(test.TestCase):
         self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
 
   def testNHWCAndNCHWTrainingProduceSameOutput(self):
+    if test.is_gpu_available(cuda_only=True):
       for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
         nhwc = self._runBatchNormalizationWithFormat(
             data_format='NHWC', shape=shape, is_training=True)

View File

@@ -220,7 +220,7 @@ def LogisticClassifier(inputs):
 
 def BatchNormClassifier(inputs):
-  inputs = layers.batch_norm(inputs, decay=0.1)
+  inputs = layers.batch_norm(inputs, decay=0.1, fused=None)
   return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)
@@ -267,6 +267,11 @@ class CreateTrainOpTest(test.TestCase):
     self._inputs = np.random.rand(16, 4).astype(np.float32)
     self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
 
+  def _addBesselsCorrection(self, sample_size, expected_var):
+    correction_factor = sample_size / (sample_size - 1)
+    expected_var *= correction_factor
+    return expected_var
+
   def testUseUpdateOps(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(0)
@@ -275,6 +280,7 @@ class CreateTrainOpTest(test.TestCase):
       expected_mean = np.mean(self._inputs, axis=(0))
       expected_var = np.var(self._inputs, axis=(0))
+      expected_var = self._addBesselsCorrection(16, expected_var)
 
       tf_predictions = BatchNormClassifier(tf_inputs)
       loss_ops.log_loss(tf_predictions, tf_labels)
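The _addBesselsCorrection helper above scales the plain NumPy (population) variance by sample_size / (sample_size - 1). The test adjustment suggests the fused path tracks the unbiased, Bessel-corrected variance in its moving statistics, which is the same quantity NumPy reports with ddof=1. A quick standalone check of that identity; the data and sample size are illustrative:

    import numpy as np

    inputs = np.random.rand(16, 4).astype(np.float32)
    sample_size = inputs.shape[0]

    biased_var = np.var(inputs, axis=0)            # divides by n
    unbiased_var = np.var(inputs, axis=0, ddof=1)  # divides by n - 1
    corrected = biased_var * sample_size / (sample_size - 1)

    np.testing.assert_allclose(corrected, unbiased_var, rtol=1e-5)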

View File

@@ -123,6 +123,10 @@ class BatchNormalization(base.Layer):
     if self.fused and renorm:
       raise ValueError(
           'Batch renorm is currently not supported with fused batch norm.')
+    if self.fused and (beta_regularizer is not None or
+                       gamma_regularizer is not None):
+      raise ValueError('Regularizers are not currently '
+                       'supported for fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -153,7 +157,12 @@ class BatchNormalization(base.Layer):
                        ' is out of range for input with rank ' + str(ndim))
 
     if self.fused is None:
-      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+      # Currently fused batch norm doesn't support renorm and beta/gamma
+      # regularizer; and only supports an input tensor of rank 4 and a channel
+      # dimension on axis 1 and 3.
+      self.fused = not self.renorm and ndim == 4 and axis in [
+          1, 3
+      ] and self.beta_regularizer is None and self.gamma_regularizer is None
 
     if self.fused:
       if axis == 1:
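A minimal sketch of the fused=None auto-selection above, using the same internal import style as the tests in this change. The shapes are illustrative, and whether the fused kernel can actually run still depends on the device and data format at session run time.

    from tensorflow.python.layers import normalization as normalization_layers
    from tensorflow.python.ops import array_ops

    bn = normalization_layers.BatchNormalization(axis=1, fused=None)
    inputs = array_ops.placeholder('float32', shape=(None, 4, 8, 8))  # rank 4
    outputs = bn.apply(inputs, training=False)
    print(bn.fused)  # True: rank 4, axis 1, no renorm, no beta/gamma regularizers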

View File

@@ -143,6 +143,7 @@ class BNTest(test.TestCase):
     self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
   def test4DInputAxis1(self):
+    if test.is_gpu_available(cuda_only=True):
       epsilon = 1e-3
       bn = normalization_layers.BatchNormalization(
           axis=1, epsilon=epsilon, momentum=0.9)
@@ -151,15 +152,15 @@ class BNTest(test.TestCase):
       training = array_ops.placeholder(dtype='bool')
       outputs = bn.apply(inputs, training=training)
 
-    with self.test_session() as sess:
+      with self.test_session(use_gpu=True) as sess:
         # Test training with placeholder learning phase.
         sess.run(variables.global_variables_initializer())
         np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
         np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
         np_beta = np.reshape(np_beta, (1, 4, 1, 1))
         for _ in range(100):
-          np_output, _, _ = sess.run([outputs] + bn.updates,
-                                     feed_dict={training: True})
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
           # Verify that the axis is normalized during training.
           normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
           self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)