Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible).
PiperOrigin-RevId: 159649743
This commit is contained in:
parent
a4a4698323
commit
b3f33ad466
@ -7,6 +7,7 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
@ -393,12 +394,11 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "normalization_test",
|
||||
size = "small",
|
||||
srcs = ["python/keras/layers/normalization_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
additional_deps = [
|
||||
":keras",
|
||||
":testing_utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -94,22 +94,23 @@ class NoiseLayersTest(test.TestCase):
|
||||
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
|
||||
|
||||
def test_batchnorm_convnet(self):
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
norm = keras.layers.BatchNormalization(
|
||||
axis=1, input_shape=(3, 4, 4), momentum=0.8)
|
||||
model.add(norm)
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
with self.test_session(use_gpu=True):
|
||||
model = keras.models.Sequential()
|
||||
norm = keras.layers.BatchNormalization(
|
||||
axis=1, input_shape=(3, 4, 4), momentum=0.8)
|
||||
model.add(norm)
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
# centered on 5.0, variance 10.0
|
||||
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
|
||||
model.fit(x, x, epochs=4, verbose=0)
|
||||
out = model.predict(x)
|
||||
out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
|
||||
out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
|
||||
# centered on 5.0, variance 10.0
|
||||
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
|
||||
model.fit(x, x, epochs=4, verbose=0)
|
||||
out = model.predict(x)
|
||||
out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
|
||||
out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
|
||||
|
||||
np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
|
||||
np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
|
||||
np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
|
||||
np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
|
||||
|
||||
def test_shared_batchnorm(self):
|
||||
"""Test that a BN layer can be shared across different data streams.
|
||||
|
@ -257,27 +257,33 @@ def _fused_batch_norm(
|
||||
'beta')
|
||||
if not param_initializers:
|
||||
param_initializers = {}
|
||||
beta_initializer = param_initializers.get('beta',
|
||||
init_ops.zeros_initializer())
|
||||
beta = variables.model_variable(
|
||||
'beta',
|
||||
shape=params_shape,
|
||||
dtype=dtype,
|
||||
initializer=beta_initializer,
|
||||
collections=beta_collections,
|
||||
trainable=trainable_beta)
|
||||
trainable_gamma = trainable and scale
|
||||
gamma_collections = utils.get_variable_collections(variables_collections,
|
||||
'gamma')
|
||||
gamma_initializer = param_initializers.get('gamma',
|
||||
init_ops.ones_initializer())
|
||||
gamma = variables.model_variable(
|
||||
'gamma',
|
||||
shape=params_shape,
|
||||
dtype=dtype,
|
||||
initializer=gamma_initializer,
|
||||
collections=gamma_collections,
|
||||
trainable=trainable_gamma)
|
||||
if center:
|
||||
beta_initializer = param_initializers.get('beta',
|
||||
init_ops.zeros_initializer())
|
||||
beta = variables.model_variable(
|
||||
'beta',
|
||||
shape=params_shape,
|
||||
dtype=dtype,
|
||||
initializer=beta_initializer,
|
||||
collections=beta_collections,
|
||||
trainable=trainable_beta)
|
||||
else:
|
||||
beta = array_ops.constant(0.0, shape=params_shape)
|
||||
|
||||
if scale:
|
||||
gamma_collections = utils.get_variable_collections(
|
||||
variables_collections, 'gamma')
|
||||
gamma_initializer = param_initializers.get('gamma',
|
||||
init_ops.ones_initializer())
|
||||
gamma = variables.model_variable(
|
||||
'gamma',
|
||||
shape=params_shape,
|
||||
dtype=dtype,
|
||||
initializer=gamma_initializer,
|
||||
collections=gamma_collections,
|
||||
trainable=trainable)
|
||||
else:
|
||||
gamma = array_ops.constant(1.0, shape=params_shape)
|
||||
|
||||
# Create moving_mean and moving_variance variables and add them to the
|
||||
# appropriate collections.
|
||||
@ -449,7 +455,8 @@ def batch_norm(inputs,
|
||||
then the batch normalization uses weighted mean and
|
||||
variance. (This can be used to correct for bias in training
|
||||
example selection.)
|
||||
fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
|
||||
fused: if `True`, use a faster, fused implementation based on
|
||||
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
|
||||
data_format: A string. `NHWC` (default) and `NCHW` are supported.
|
||||
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
|
||||
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
|
||||
@ -473,7 +480,6 @@ def batch_norm(inputs,
|
||||
|
||||
Raises:
|
||||
ValueError: If `batch_weights` is not None and `fused` is True.
|
||||
ValueError: If `param_regularizers` is not None and `fused` is True.
|
||||
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
|
||||
ValueError: If the rank of `inputs` is undefined.
|
||||
ValueError: If rank or channels dimension of `inputs` is undefined.
|
||||
@ -487,6 +493,21 @@ def batch_norm(inputs,
|
||||
'supported for fused batch norm.')
|
||||
if renorm:
|
||||
raise ValueError('Renorm is not supported for fused batch norm.')
|
||||
|
||||
# Only use _fused_batch_norm (1) if fused is set True or if it is
|
||||
# possible to use (currently it doesn't support batch weights,
|
||||
# renorm, and the case when rank is neither 2 nor 4),
|
||||
# and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
|
||||
# or non-default updates_collections (not implemented in
|
||||
# normalization_layers.BatchNormalization yet); otherwise use the fused
|
||||
# implementation in normalization_layers.BatchNormalization.
|
||||
inputs = ops.convert_to_tensor(inputs)
|
||||
rank = inputs.get_shape().ndims
|
||||
feature_supported = batch_weights is None and not renorm and rank in [2, 4]
|
||||
possible_to_fuse = fused is None and feature_supported
|
||||
if (fused or possible_to_fuse) and (
|
||||
zero_debias_moving_mean or rank == 2 or
|
||||
updates_collections is not ops.GraphKeys.UPDATE_OPS):
|
||||
return _fused_batch_norm(
|
||||
inputs,
|
||||
decay=decay,
|
||||
@ -552,7 +573,8 @@ def batch_norm(inputs,
|
||||
renorm_momentum=renorm_decay,
|
||||
name=sc.name,
|
||||
_scope=sc,
|
||||
_reuse=reuse)
|
||||
_reuse=reuse,
|
||||
fused=fused)
|
||||
outputs = layer.apply(inputs, training=is_training)
|
||||
|
||||
# Add variables to collections.
|
||||
|
@ -1703,13 +1703,6 @@ class BatchNormTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
|
||||
_layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
|
||||
|
||||
def testParamRegularizersFused(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Regularizers are not currently'):
|
||||
_layers.batch_norm(inputs, param_regularizers={}, fused=True)
|
||||
|
||||
def _testCreateOp(self, fused):
|
||||
height, width = 3, 3
|
||||
with self.test_session():
|
||||
@ -1780,7 +1773,8 @@ class BatchNormTest(test.TestCase):
|
||||
height, width = 3, 3
|
||||
with self.test_session():
|
||||
images = random_ops.random_uniform((5, height, width, 3), seed=1)
|
||||
_layers.batch_norm(images, scale=True, zero_debias_moving_mean=True)
|
||||
_layers.batch_norm(
|
||||
images, scale=True, zero_debias_moving_mean=True, fused=False)
|
||||
self.assertEqual(len(variables.get_model_variables()), 6)
|
||||
moving_mean = variables.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = variables.get_variables_by_name('moving_variance')[0]
|
||||
@ -1874,7 +1868,8 @@ class BatchNormTest(test.TestCase):
|
||||
images,
|
||||
decay=0.1,
|
||||
updates_collections=None,
|
||||
zero_debias_moving_mean=True)
|
||||
zero_debias_moving_mean=True,
|
||||
fused=False)
|
||||
moving_mean = variables.get_variables_by_name('BatchNorm/moving_mean')[0]
|
||||
moving_variance = variables.get_variables_by_name('moving_variance')[0]
|
||||
biased = variables.get_variables_by_name('biased')[0]
|
||||
@ -2523,7 +2518,7 @@ class BatchNormTest(test.TestCase):
|
||||
|
||||
def _runBatchNormalizationWithFormat(self, shape, data_format, is_training):
|
||||
channels = shape[-1]
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
images = np.arange(np.product(shape), dtype=np.float32).reshape(shape)
|
||||
beta = init_ops.constant_initializer(
|
||||
np.arange(
|
||||
@ -2561,20 +2556,22 @@ class BatchNormTest(test.TestCase):
|
||||
return sess.run(output)
|
||||
|
||||
def testNHWCAndNCHWInferenceProduceSameOutput(self):
|
||||
for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
|
||||
nhwc = self._runBatchNormalizationWithFormat(
|
||||
data_format='NHWC', shape=shape, is_training=False)
|
||||
nchw = self._runBatchNormalizationWithFormat(
|
||||
data_format='NCHW', shape=shape, is_training=False)
|
||||
self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
|
||||
nhwc = self._runBatchNormalizationWithFormat(
|
||||
data_format='NHWC', shape=shape, is_training=False)
|
||||
nchw = self._runBatchNormalizationWithFormat(
|
||||
data_format='NCHW', shape=shape, is_training=False)
|
||||
self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def testNHWCAndNCHWTrainingProduceSameOutput(self):
|
||||
for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
|
||||
nhwc = self._runBatchNormalizationWithFormat(
|
||||
data_format='NHWC', shape=shape, is_training=True)
|
||||
nchw = self._runBatchNormalizationWithFormat(
|
||||
data_format='NCHW', shape=shape, is_training=True)
|
||||
self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
|
||||
nhwc = self._runBatchNormalizationWithFormat(
|
||||
data_format='NHWC', shape=shape, is_training=True)
|
||||
nchw = self._runBatchNormalizationWithFormat(
|
||||
data_format='NCHW', shape=shape, is_training=True)
|
||||
self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
class LayerNormTest(test.TestCase):
|
||||
|
@ -220,7 +220,7 @@ def LogisticClassifier(inputs):
|
||||
|
||||
|
||||
def BatchNormClassifier(inputs):
|
||||
inputs = layers.batch_norm(inputs, decay=0.1)
|
||||
inputs = layers.batch_norm(inputs, decay=0.1, fused=None)
|
||||
return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)
|
||||
|
||||
|
||||
@ -267,6 +267,11 @@ class CreateTrainOpTest(test.TestCase):
|
||||
self._inputs = np.random.rand(16, 4).astype(np.float32)
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
def _addBesselsCorrection(self, sample_size, expected_var):
|
||||
correction_factor = sample_size / (sample_size - 1)
|
||||
expected_var *= correction_factor
|
||||
return expected_var
|
||||
|
||||
def testUseUpdateOps(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(0)
|
||||
@ -275,6 +280,7 @@ class CreateTrainOpTest(test.TestCase):
|
||||
|
||||
expected_mean = np.mean(self._inputs, axis=(0))
|
||||
expected_var = np.var(self._inputs, axis=(0))
|
||||
expected_var = self._addBesselsCorrection(16, expected_var)
|
||||
|
||||
tf_predictions = BatchNormClassifier(tf_inputs)
|
||||
loss_ops.log_loss(tf_predictions, tf_labels)
|
||||
|
@ -123,6 +123,10 @@ class BatchNormalization(base.Layer):
|
||||
if self.fused and renorm:
|
||||
raise ValueError(
|
||||
'Batch renorm is currently not supported with fused batch norm.')
|
||||
if self.fused and (beta_regularizer is not None or
|
||||
gamma_regularizer is not None):
|
||||
raise ValueError('Regularizers are not currently '
|
||||
'supported for fused batch norm.')
|
||||
if renorm:
|
||||
renorm_clipping = renorm_clipping or {}
|
||||
keys = ['rmax', 'rmin', 'dmax']
|
||||
@ -153,7 +157,12 @@ class BatchNormalization(base.Layer):
|
||||
' is out of range for input with rank ' + str(ndim))
|
||||
|
||||
if self.fused is None:
|
||||
self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
|
||||
# Currently fused batch norm doesn't support renorm and beta/gamma
|
||||
# regularizer; and only supports an input tensor of rank 4 and a channel
|
||||
# dimension on axis 1 and 3.
|
||||
self.fused = not self.renorm and ndim == 4 and axis in [
|
||||
1, 3
|
||||
] and self.beta_regularizer is None and self.gamma_regularizer is None
|
||||
|
||||
if self.fused:
|
||||
if axis == 1:
|
||||
|
@ -143,45 +143,46 @@ class BNTest(test.TestCase):
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
def test4DInputAxis1(self):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
axis=1, epsilon=epsilon, momentum=0.9)
|
||||
inputs = variables.Variable(
|
||||
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
|
||||
training = array_ops.placeholder(dtype='bool')
|
||||
outputs = bn.apply(inputs, training=training)
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
axis=1, epsilon=epsilon, momentum=0.9)
|
||||
inputs = variables.Variable(
|
||||
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
|
||||
training = array_ops.placeholder(dtype='bool')
|
||||
outputs = bn.apply(inputs, training=training)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test training with placeholder learning phase.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
|
||||
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
|
||||
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
|
||||
for _ in range(100):
|
||||
np_output, _, _ = sess.run([outputs] + bn.updates,
|
||||
feed_dict={training: True})
|
||||
# Verify that the axis is normalized during training.
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
# Test training with placeholder learning phase.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
|
||||
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
|
||||
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
|
||||
for _ in range(100):
|
||||
np_output, _, _ = sess.run(
|
||||
[outputs] + bn.updates, feed_dict={training: True})
|
||||
# Verify that the axis is normalized during training.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
# Verify that the statistics are updated during training.
|
||||
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
|
||||
np_inputs = sess.run(inputs)
|
||||
mean = np.mean(np_inputs, axis=(0, 2, 3))
|
||||
std = np.std(np_inputs, axis=(0, 2, 3))
|
||||
variance = np.square(std)
|
||||
self.assertAllClose(mean, moving_mean, atol=1e-2)
|
||||
self.assertAllClose(variance, moving_var, atol=1e-2)
|
||||
|
||||
# Test inference with placeholder learning phase.
|
||||
np_output = sess.run(outputs, feed_dict={training: False})
|
||||
|
||||
# Verify that the axis is normalized during inference.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
# Verify that the statistics are updated during training.
|
||||
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
|
||||
np_inputs = sess.run(inputs)
|
||||
mean = np.mean(np_inputs, axis=(0, 2, 3))
|
||||
std = np.std(np_inputs, axis=(0, 2, 3))
|
||||
variance = np.square(std)
|
||||
self.assertAllClose(mean, moving_mean, atol=1e-2)
|
||||
self.assertAllClose(variance, moving_var, atol=1e-2)
|
||||
|
||||
# Test inference with placeholder learning phase.
|
||||
np_output = sess.run(outputs, feed_dict={training: False})
|
||||
|
||||
# Verify that the axis is normalized during inference.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
def test4DInputAxis2(self):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
|
Loading…
Reference in New Issue
Block a user