Support leaving the offset (beta) parameter out in batch_normalization, in which case no offset will be added after normalization.

Change: 115489328
A. Unique TensorFlower 2016-02-24 13:53:59 -08:00 committed by TensorFlower Gardener
parent 87a289103f
commit 4afef14f02
2 changed files with 91 additions and 68 deletions
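Not part of the commit: a minimal usage sketch of the behavior described in the commit message, assuming the graph-mode TensorFlow API of this era (tf.Session, tf.nn.moments); the tensors and shapes below are illustrative only. Passing offset=None now skips the final addition, while a real offset tensor is still added back after scaling.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(3, 4).astype(np.float32))
mean, variance = tf.nn.moments(x, axes=[0])
scale = tf.ones([4])
beta = tf.constant([0.5, 0.5, 0.5, 0.5])
eps = 0.001

# offset supplied: y = (x - mean) * rsqrt(variance + eps) * scale + beta
y_shifted = tf.nn.batch_normalization(x, mean, variance, beta, scale, eps)
# offset omitted (None): the beta addition is skipped entirely
y_unshifted = tf.nn.batch_normalization(x, mean, variance, None, scale, eps)

with tf.Session() as sess:
  shifted, unshifted = sess.run([y_shifted, y_unshifted])
  print(shifted - unshifted)  # roughly 0.5 everywhere, i.e. exactly beta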


@@ -593,8 +593,8 @@ def batch_normalization(x,
     x: Input `Tensor` of arbitrary dimensionality.
     mean: A mean `Tensor`.
     variance: A variance `Tensor`.
-    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, to
-      be applied to the normalized tensor.
+    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, or
+      None. If present, will be added to the normalized tensor.
     scale: A scale `Tensor`, often denoted \\\\(\gamma\\\\) in equations, or
       `None`. If present, the scale is applied to the normalized tensor.
     variance_epsilon: A small float number to avoid dividing by 0.
@@ -607,7 +607,7 @@ def batch_normalization(x,
   inv = math_ops.rsqrt(variance + variance_epsilon)
   if scale is not None:
     inv *= scale
-  return x * inv + (offset - mean * inv)
+  return x * inv + (offset - mean * inv if offset else -mean * inv)
 
 
 def batch_norm_with_global_normalization(t,
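As a quick sanity check of the algebra behind the rewritten return expression (NumPy, not part of the commit): with inv = rsqrt(variance + eps) * scale, the factored form x * inv + (offset - mean * inv) equals the textbook (x - mean) / sqrt(variance + eps) * scale + offset, and leaving the offset out simply drops the final addition.

import numpy as np

rng = np.random.RandomState(0)
x = rng.rand(3, 4).astype(np.float32)
mean, variance = x.mean(axis=0), x.var(axis=0)
scale, offset, eps = rng.rand(4), rng.rand(4), 1e-3

inv = scale / np.sqrt(variance + eps)           # scale * rsqrt(variance + eps)
with_offset = x * inv + (offset - mean * inv)   # offset supplied
without_offset = x * inv - mean * inv           # offset left out (None)

reference = (x - mean) / np.sqrt(variance + eps) * scale
assert np.allclose(with_offset, reference + offset)
assert np.allclose(without_offset, reference)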


@@ -479,19 +479,17 @@ class DropoutTest(tf.test.TestCase):
 class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
 
   def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                   scale_after_normalization):
+                   scale_after_normalization, shift_after_normalization):
     y = (x - m) / np.sqrt(v + epsilon)
     y = y * gamma if scale_after_normalization else y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y
 
   def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                    scale_after_normalization):
+                    scale_after_normalization, shift_after_normalization):
     y = (x - m) * tf.rsqrt(v + epsilon)
     if scale_after_normalization:
       y = gamma * y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y
 
   def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                      scale_after_normalization):
@@ -510,10 +508,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         x, m, v, beta, gamma, epsilon, scale_after_normalization)
 
   def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
-                     scale_after_normalization):
+                     scale_after_normalization, shift_after_normalization):
     """New implementation."""
     return tf.nn.batch_normalization(
-        x, m, v, beta, gamma if scale_after_normalization else None, epsilon)
+        x, m, v, beta if shift_after_normalization else None,
+        gamma if scale_after_normalization else None, epsilon)
 
   def testBatchNorm(self):
     x_shape = [3, 5, 4, 2]
@@ -532,29 +531,35 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         gamma = tf.constant(gamma_val, name="gamma")
         epsilon = 0.001
         for scale_after_normalization in [True, False]:
-          bn2 = self._tfBatchNormV2(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          bn1bw = self._tfBatchNormV1BW(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          bn1 = self._tfBatchNormV1(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          on = self._opsBatchNorm(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          np_bn = self._npBatchNorm(
-              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-              scale_after_normalization)
-          tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
-              [bn2, bn1bw, bn1, on])
-          self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
-          self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
-          self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
-          self.assertAllClose(np_bn, ops_bn, atol=0.000001)
-          self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
-          self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
-          self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+          for shift_after_normalization in [True, False]:
+            bn2 = self._tfBatchNormV2(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            bn1bw = self._tfBatchNormV1BW(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            bn1 = self._tfBatchNormV1(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            on = self._opsBatchNorm(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            np_bn = self._npBatchNorm(
+                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+                scale_after_normalization, shift_after_normalization)
+            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
+                [bn2, bn1bw, bn1, on])
+            self.assertAllClose(np_bn, ops_bn, atol=0.000001)
+            self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
+            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+            # shift_after_normalization=False is not supported in v1.
+            if shift_after_normalization:
+              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
+              self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
+              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
+              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
 
   def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
-                             version, err_tolerance=1e-11):
+                             shift_after_normalization, version,
+                             err_tolerance=1e-11):
     x_shape = [3, 5, 4, 5]
     param_shape = [5]
     np.random.seed(1)  # Make it reproducible.
@@ -575,7 +580,8 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             x, m, v, beta, gamma, epsilon, scale_after_normalization)
       elif version == 2:
         output = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization,
+            shift_after_normalization)
       else:
         print("Invalid version", version)
         raise
@@ -583,31 +589,39 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
       err = tf.test.compute_gradient_error(
           all_params[param_index], all_shapes[param_index], output, x_shape)
-    print("Batch normalization v%d %s gradient %s scale err = " %
-          (version, tag, "with" if scale_after_normalization else "without"),
+    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
+          (version, tag, "with" if scale_after_normalization else "without",
+           "with" if shift_after_normalization else "without"),
          err)
     self.assertLess(err, err_tolerance)
 
-  def testBatchNormInputGradient(self):
+  def _testBatchNormGradientInAllNeedConfigs(
+      self, param_index, tag, err_tolerance=1e-11):
     for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(0, "x", scale_after_normalization, v)
+      for shift_after_normalization in [True, False]:
+        # shift_after_normalization=False is not supported in version 1.
+        for v in ([1, 2] if shift_after_normalization else [2]):
+          self._testBatchNormGradient(
+              param_index, tag, scale_after_normalization,
+              shift_after_normalization, v, err_tolerance)
+
+  def testBatchNormInputGradient(self):
+    self._testBatchNormGradientInAllNeedConfigs(0, "x")
 
   def testBatchNormMeanGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(1, "mean", scale_after_normalization, v)
+    self._testBatchNormGradientInAllNeedConfigs(1, "mean")
 
   def testBatchNormVarianceGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(2, "variance", scale_after_normalization, v,
-                                    err_tolerance=1e-03)
+    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
+                                                err_tolerance=1e-03)
 
   def testBatchNormBetaGradient(self):
+    # Since beta does not exist when scale_after_normalization=False, we only
+    # test for scale_after_normalization=True.
     for scale_after_normalization in [True, False]:
       for v in [1, 2]:
-        self._testBatchNormGradient(3, "beta", scale_after_normalization, v)
+        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
+                                    v)
 
   def testBatchNormGammaGradient(self):
     # If scale_after_normalization is False, backprop for gamma in v1
@@ -615,8 +629,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     # gamma is not used at all, and the gradient is None, which displeases the
     # gradient checker.
     for scale_after_normalization in [True, False]:
-      self._testBatchNormGradient(4, "gamma", scale_after_normalization, 1)
-    self._testBatchNormGradient(4, "gamma", True, 2)
+      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
+                                  1)
+    for shift_after_normalization in [True, False]:
+      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
+                                  2)
 
   def testBatchNormGradImpl(self):
     x_shape = [7, 5, 4, 6]
@@ -644,7 +661,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             gen_nn_ops._batch_norm_with_global_normalization_grad(
                 x, m, v, gamma, backprop, epsilon, scale_after_normalization))
         on = self._opsBatchNorm(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
         odx, odm, odv, odb, odg = tf.gradients(
             [on], [x, m, v, beta, gamma], [backprop])
         if scale_after_normalization:
@@ -687,16 +704,20 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
           gamma, keep_dims_param_shape, name="keep_dims_gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        keep_dims_bn = self._tfBatchNormV2(
-            x, keep_dims_m, keep_dims_v, keep_dims_beta,
-            keep_dims_gamma, epsilon, scale_after_normalization)
-        tf_batch_norm, keep_dims_tf_batch_norm = sess.run([bn, keep_dims_bn])
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
-        self.assertAllClose(
-            tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          keep_dims_bn = self._tfBatchNormV2(
+              x, keep_dims_m, keep_dims_v, keep_dims_beta,
+              keep_dims_gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
+              [bn, keep_dims_bn])
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
+          self.assertAllClose(
+              tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
 
   def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.000001):
     x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -713,15 +734,17 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       gamma = tf.constant(gamma_val, name="gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        np_batch_norm = self._npBatchNorm(
-            x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-            scale_after_normalization)
-        [tf_batch_norm] = sess.run([bn])
-        self.assertEquals(x_shape, np_batch_norm.shape)
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          np_batch_norm = self._npBatchNorm(
+              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+              scale_after_normalization, shift_after_normalization)
+          [tf_batch_norm] = sess.run([bn])
+          self.assertEquals(x_shape, np_batch_norm.shape)
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
 
   def testBatchNormArbitraryShapes(self):
     """Test for a variety of shapes and moments.