Prepare for the fused option of batch norm to be set to None (None means: use fused batch norm if possible).
PiperOrigin-RevId: 159649743
parent a4a4698323, commit b3f33ad466
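
As context for the diff below, here is a minimal usage sketch (not part of the commit) of the three values the fused argument is meant to accept once this change is in, assuming a TensorFlow 1.x build that includes it; the placeholder shape is invented for illustration.

import tensorflow as tf
from tensorflow.contrib import layers

# A made-up NHWC feature map, just to have something to normalize.
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 64))

# fused=True: always use the fused kernel (nn.fused_batch_norm); unsupported
# combinations such as batch_weights or renorm raise a ValueError.
bn_fused = layers.batch_norm(inputs, decay=0.9, fused=True)

# fused=False: always use the plain (non-fused) implementation.
bn_plain = layers.batch_norm(inputs, decay=0.9, fused=False)

# fused=None: use the fused implementation whenever the input rank and the
# other arguments allow it, otherwise fall back to the non-fused path.
bn_auto = layers.batch_norm(inputs, decay=0.9, fused=None)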
@@ -7,6 +7,7 @@ exports_files(["LICENSE"])
 
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 py_library(
@@ -393,12 +394,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "normalization_test",
     size = "small",
     srcs = ["python/keras/layers/normalization_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":keras",
         ":testing_utils",
         "//tensorflow/python:client_testlib",
@@ -94,22 +94,23 @@ class NoiseLayersTest(test.TestCase):
     np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
 
   def test_batchnorm_convnet(self):
-    with self.test_session():
-      model = keras.models.Sequential()
-      norm = keras.layers.BatchNormalization(
-          axis=1, input_shape=(3, 4, 4), momentum=0.8)
-      model.add(norm)
-      model.compile(loss='mse', optimizer='sgd')
+    if test.is_gpu_available(cuda_only=True):
+      with self.test_session(use_gpu=True):
+        model = keras.models.Sequential()
+        norm = keras.layers.BatchNormalization(
+            axis=1, input_shape=(3, 4, 4), momentum=0.8)
+        model.add(norm)
+        model.compile(loss='mse', optimizer='sgd')
 
-      # centered on 5.0, variance 10.0
-      x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
-      model.fit(x, x, epochs=4, verbose=0)
-      out = model.predict(x)
-      out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
-      out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
+        # centered on 5.0, variance 10.0
+        x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
+        model.fit(x, x, epochs=4, verbose=0)
+        out = model.predict(x)
+        out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
+        out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
 
-      np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
-      np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
+        np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
+        np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
 
   def test_shared_batchnorm(self):
     """Test that a BN layer can be shared across different data streams.
@@ -257,27 +257,33 @@ def _fused_batch_norm(
                                                       'beta')
     if not param_initializers:
       param_initializers = {}
-    beta_initializer = param_initializers.get('beta',
-                                              init_ops.zeros_initializer())
-    beta = variables.model_variable(
-        'beta',
-        shape=params_shape,
-        dtype=dtype,
-        initializer=beta_initializer,
-        collections=beta_collections,
-        trainable=trainable_beta)
-    trainable_gamma = trainable and scale
-    gamma_collections = utils.get_variable_collections(variables_collections,
-                                                       'gamma')
-    gamma_initializer = param_initializers.get('gamma',
-                                               init_ops.ones_initializer())
-    gamma = variables.model_variable(
-        'gamma',
-        shape=params_shape,
-        dtype=dtype,
-        initializer=gamma_initializer,
-        collections=gamma_collections,
-        trainable=trainable_gamma)
+    if center:
+      beta_initializer = param_initializers.get('beta',
+                                                init_ops.zeros_initializer())
+      beta = variables.model_variable(
+          'beta',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=beta_initializer,
+          collections=beta_collections,
+          trainable=trainable_beta)
+    else:
+      beta = array_ops.constant(0.0, shape=params_shape)
+
+    if scale:
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
+      gamma_initializer = param_initializers.get('gamma',
+                                                 init_ops.ones_initializer())
+      gamma = variables.model_variable(
+          'gamma',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=gamma_initializer,
+          collections=gamma_collections,
+          trainable=trainable)
+    else:
+      gamma = array_ops.constant(1.0, shape=params_shape)
 
     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections.
@@ -449,7 +455,8 @@ def batch_norm(inputs,
       then the batch normalization uses weighted mean and
       variance. (This can be used to correct for bias in training
       example selection.)
-    fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
     data_format: A string. `NHWC` (default) and `NCHW` are supported.
     zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
       pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -473,7 +480,6 @@ def batch_norm(inputs,
 
   Raises:
     ValueError: If `batch_weights` is not None and `fused` is True.
-    ValueError: If `param_regularizers` is not None and `fused` is True.
     ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
     ValueError: If the rank of `inputs` is undefined.
     ValueError: If rank or channels dimension of `inputs` is undefined.
@@ -487,6 +493,21 @@ def batch_norm(inputs,
                        'supported for fused batch norm.')
     if renorm:
       raise ValueError('Renorm is not supported for fused batch norm.')
+
+  # Only use _fused_batch_norm (1) if fused is set True or if it is
+  # possible to use (currently it doesn't support batch weights,
+  # renorm, and the case when rank is neither 2 nor 4),
+  # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
+  # or non-default updates_collections (not implemented in
+  # normalization_layers.BatchNormalization yet); otherwise use the fused
+  # implementation in normalization_layers.BatchNormalization.
+  inputs = ops.convert_to_tensor(inputs)
+  rank = inputs.get_shape().ndims
+  feature_supported = batch_weights is None and not renorm and rank in [2, 4]
+  possible_to_fuse = fused is None and feature_supported
+  if (fused or possible_to_fuse) and (
+      zero_debias_moving_mean or rank == 2 or
+      updates_collections is not ops.GraphKeys.UPDATE_OPS):
     return _fused_batch_norm(
         inputs,
         decay=decay,
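
To restate the routing comment added in the hunk above in standalone form, the snippet below is an illustrative paraphrase of the new decision, not the library code; the function name and the default_updates_collections flag are invented for the example.

def _uses_contrib_fused_path(fused, rank, batch_weights, renorm,
                             zero_debias_moving_mean,
                             default_updates_collections):
  # (1) Fusing is either requested explicitly (fused=True) or possible with
  # fused=None: no batch weights, no renorm, and an input of rank 2 or 4.
  feature_supported = batch_weights is None and not renorm and rank in (2, 4)
  possible_to_fuse = fused is None and feature_supported
  # (2) The contrib _fused_batch_norm path is kept only for what the core
  # layer cannot do yet: zero-debiased moving mean, rank-2 inputs, or
  # non-default update collections.
  return bool(fused or possible_to_fuse) and (
      zero_debias_moving_mean or rank == 2 or not default_updates_collections)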
@@ -552,7 +573,8 @@ def batch_norm(inputs,
         renorm_momentum=renorm_decay,
         name=sc.name,
         _scope=sc,
-        _reuse=reuse)
+        _reuse=reuse,
+        fused=fused)
     outputs = layer.apply(inputs, training=is_training)
 
     # Add variables to collections.
@@ -1703,13 +1703,6 @@ class BatchNormTest(test.TestCase):
       with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
         _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
 
-  def testParamRegularizersFused(self):
-    with ops.Graph().as_default() as g, self.test_session(g):
-      inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
-      with self.assertRaisesRegexp(ValueError,
-                                   'Regularizers are not currently'):
-        _layers.batch_norm(inputs, param_regularizers={}, fused=True)
-
   def _testCreateOp(self, fused):
     height, width = 3, 3
     with self.test_session():
@@ -1780,7 +1773,8 @@ class BatchNormTest(test.TestCase):
     height, width = 3, 3
     with self.test_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
-      _layers.batch_norm(images, scale=True, zero_debias_moving_mean=True)
+      _layers.batch_norm(
+          images, scale=True, zero_debias_moving_mean=True, fused=False)
       self.assertEqual(len(variables.get_model_variables()), 6)
       moving_mean = variables.get_variables_by_name('moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
@@ -1874,7 +1868,8 @@ class BatchNormTest(test.TestCase):
           images,
           decay=0.1,
           updates_collections=None,
-          zero_debias_moving_mean=True)
+          zero_debias_moving_mean=True,
+          fused=False)
       moving_mean = variables.get_variables_by_name('BatchNorm/moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
       biased = variables.get_variables_by_name('biased')[0]
@@ -2523,7 +2518,7 @@ class BatchNormTest(test.TestCase):
 
   def _runBatchNormalizationWithFormat(self, shape, data_format, is_training):
     channels = shape[-1]
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       images = np.arange(np.product(shape), dtype=np.float32).reshape(shape)
       beta = init_ops.constant_initializer(
           np.arange(
@@ -2561,20 +2556,22 @@ class BatchNormTest(test.TestCase):
       return sess.run(output)
 
   def testNHWCAndNCHWInferenceProduceSameOutput(self):
-    for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
-      nhwc = self._runBatchNormalizationWithFormat(
-          data_format='NHWC', shape=shape, is_training=False)
-      nchw = self._runBatchNormalizationWithFormat(
-          data_format='NCHW', shape=shape, is_training=False)
-      self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
+    if test.is_gpu_available(cuda_only=True):
+      for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
+        nhwc = self._runBatchNormalizationWithFormat(
+            data_format='NHWC', shape=shape, is_training=False)
+        nchw = self._runBatchNormalizationWithFormat(
+            data_format='NCHW', shape=shape, is_training=False)
+        self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
 
   def testNHWCAndNCHWTrainingProduceSameOutput(self):
-    for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
-      nhwc = self._runBatchNormalizationWithFormat(
-          data_format='NHWC', shape=shape, is_training=True)
-      nchw = self._runBatchNormalizationWithFormat(
-          data_format='NCHW', shape=shape, is_training=True)
-      self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
+    if test.is_gpu_available(cuda_only=True):
+      for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
+        nhwc = self._runBatchNormalizationWithFormat(
+            data_format='NHWC', shape=shape, is_training=True)
+        nchw = self._runBatchNormalizationWithFormat(
+            data_format='NCHW', shape=shape, is_training=True)
+        self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
 
 
 class LayerNormTest(test.TestCase):
@@ -220,7 +220,7 @@ def LogisticClassifier(inputs):
 
 
 def BatchNormClassifier(inputs):
-  inputs = layers.batch_norm(inputs, decay=0.1)
+  inputs = layers.batch_norm(inputs, decay=0.1, fused=None)
   return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)
 
 
@@ -267,6 +267,11 @@ class CreateTrainOpTest(test.TestCase):
     self._inputs = np.random.rand(16, 4).astype(np.float32)
     self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
 
+  def _addBesselsCorrection(self, sample_size, expected_var):
+    correction_factor = sample_size / (sample_size - 1)
+    expected_var *= correction_factor
+    return expected_var
+
   def testUseUpdateOps(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(0)
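
The new _addBesselsCorrection helper above only rescales a variance by n / (n - 1); presumably it is needed because the fused implementation reports a Bessel-corrected variance for the moving-average update. A small illustrative check of that arithmetic (not part of the commit), using only NumPy:

import numpy as np

sample_size = 16  # same batch size the test uses
data = np.random.rand(sample_size, 4).astype(np.float32)

biased_var = data.var(axis=0)                         # divides by n
corrected = biased_var * sample_size / (sample_size - 1)
unbiased_var = data.var(axis=0, ddof=1)               # divides by n - 1

np.testing.assert_allclose(corrected, unbiased_var, rtol=1e-5)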
@@ -275,6 +280,7 @@ class CreateTrainOpTest(test.TestCase):
 
       expected_mean = np.mean(self._inputs, axis=(0))
       expected_var = np.var(self._inputs, axis=(0))
+      expected_var = self._addBesselsCorrection(16, expected_var)
 
       tf_predictions = BatchNormClassifier(tf_inputs)
       loss_ops.log_loss(tf_predictions, tf_labels)
@@ -123,6 +123,10 @@ class BatchNormalization(base.Layer):
     if self.fused and renorm:
       raise ValueError(
           'Batch renorm is currently not supported with fused batch norm.')
+    if self.fused and (beta_regularizer is not None or
+                       gamma_regularizer is not None):
+      raise ValueError('Regularizers are not currently '
+                       'supported for fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -153,7 +157,12 @@ class BatchNormalization(base.Layer):
                        ' is out of range for input with rank ' + str(ndim))
 
     if self.fused is None:
-      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+      # Currently fused batch norm doesn't support renorm and beta/gamma
+      # regularizer; and only supports an input tensor of rank 4 and a channel
+      # dimension on axis 1 and 3.
+      self.fused = not self.renorm and ndim == 4 and axis in [
+          1, 3
+      ] and self.beta_regularizer is None and self.gamma_regularizer is None
 
     if self.fused:
       if axis == 1:
@@ -143,45 +143,46 @@ class BNTest(test.TestCase):
     self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
   def test4DInputAxis1(self):
-    epsilon = 1e-3
-    bn = normalization_layers.BatchNormalization(
-        axis=1, epsilon=epsilon, momentum=0.9)
-    inputs = variables.Variable(
-        np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
-    training = array_ops.placeholder(dtype='bool')
-    outputs = bn.apply(inputs, training=training)
+    if test.is_gpu_available(cuda_only=True):
+      epsilon = 1e-3
+      bn = normalization_layers.BatchNormalization(
+          axis=1, epsilon=epsilon, momentum=0.9)
+      inputs = variables.Variable(
+          np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+      training = array_ops.placeholder(dtype='bool')
+      outputs = bn.apply(inputs, training=training)
 
-    with self.test_session() as sess:
-      # Test training with placeholder learning phase.
-      sess.run(variables.global_variables_initializer())
-      np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
-      np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
-      np_beta = np.reshape(np_beta, (1, 4, 1, 1))
-      for _ in range(100):
-        np_output, _, _ = sess.run([outputs] + bn.updates,
-                                   feed_dict={training: True})
-      # Verify that the axis is normalized during training.
-      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
-      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
-      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+      with self.test_session(use_gpu=True) as sess:
+        # Test training with placeholder learning phase.
+        sess.run(variables.global_variables_initializer())
+        np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+        np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+        np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+        for _ in range(100):
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
+        # Verify that the axis is normalized during training.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
-      # Verify that the statistics are updated during training.
-      moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
-      np_inputs = sess.run(inputs)
-      mean = np.mean(np_inputs, axis=(0, 2, 3))
-      std = np.std(np_inputs, axis=(0, 2, 3))
-      variance = np.square(std)
-      self.assertAllClose(mean, moving_mean, atol=1e-2)
-      self.assertAllClose(variance, moving_var, atol=1e-2)
+        # Verify that the statistics are updated during training.
+        moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+        np_inputs = sess.run(inputs)
+        mean = np.mean(np_inputs, axis=(0, 2, 3))
+        std = np.std(np_inputs, axis=(0, 2, 3))
+        variance = np.square(std)
+        self.assertAllClose(mean, moving_mean, atol=1e-2)
+        self.assertAllClose(variance, moving_var, atol=1e-2)
 
-      # Test inference with placeholder learning phase.
-      np_output = sess.run(outputs, feed_dict={training: False})
+        # Test inference with placeholder learning phase.
+        np_output = sess.run(outputs, feed_dict={training: False})
 
-      # Verify that the axis is normalized during inference.
-      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
-      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
-      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+        # Verify that the axis is normalized during inference.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
   def test4DInputAxis2(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(