Support leaving the offset (beta) parameter out in batch_normalization, in which case no offset will be added after normalization.
Change: 115489328
parent 87a289103f
commit 4afef14f02
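For context, the new behavior can be exercised as in the sketch below (not part of the commit; the toy values are illustrative, and the session-based API reflects TensorFlow of this era). Passing offset=None now skips the additive beta, just as scale=None already skipped the multiplicative gamma:

import tensorflow as tf

# Toy input: normalize over the batch dimension without a learned beta/gamma.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
mean, variance = tf.nn.moments(x, axes=[0])

# offset=None: nothing is added after normalization (this commit).
# scale=None: no multiplicative gamma is applied (already supported).
y = tf.nn.batch_normalization(x, mean, variance,
                              offset=None, scale=None,
                              variance_epsilon=0.001)

with tf.Session() as sess:
  print(sess.run(y))  # approximately (x - mean) / sqrt(variance + epsilon)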
@@ -593,8 +593,8 @@ def batch_normalization(x,
     x: Input `Tensor` of arbitrary dimensionality.
     mean: A mean `Tensor`.
     variance: A variance `Tensor`.
-    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, to
-      be applied to the normalized tensor.
+    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, or
+      None. If present, will be added to the normalized tensor.
     scale: A scale `Tensor`, often denoted \\\\(\gamma\\\\) in equations, or
       `None`. If present, the scale is applied to the normalized tensor.
     variance_epsilon: A small float number to avoid dividing by 0.
@@ -607,7 +607,7 @@ def batch_normalization(x,
   inv = math_ops.rsqrt(variance + variance_epsilon)
   if scale is not None:
     inv *= scale
-  return x * inv + (offset - mean * inv)
+  return x * inv + (offset - mean * inv if offset else -mean * inv)


 def batch_norm_with_global_normalization(t,
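As an aside (not part of the diff), the returned expression is the usual batch-norm formula gamma * (x - mean) / sqrt(variance + epsilon) + beta, refactored so that x is multiplied by a single precomputed factor inv. A small NumPy sketch, with illustrative shapes, checking that the refactoring is equivalent:

import numpy as np

np.random.seed(0)
x = np.random.rand(4, 3).astype(np.float32)
mean = x.mean(axis=0)
variance = x.var(axis=0)
gamma = np.random.rand(3).astype(np.float32)  # scale
beta = np.random.rand(3).astype(np.float32)   # offset
eps = 0.001

# Direct form: gamma * (x - mean) / sqrt(variance + eps) + beta.
direct = gamma * (x - mean) / np.sqrt(variance + eps) + beta

# Refactored form from the diff: x * inv + (offset - mean * inv),
# where inv = gamma / sqrt(variance + eps).
inv = gamma / np.sqrt(variance + eps)
refactored = x * inv + (beta - mean * inv)

assert np.allclose(direct, refactored, atol=1e-6)

The remaining hunks update the batch-norm tests so the helpers take a shift_after_normalization flag, mapping it to offset=None when False.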
@@ -479,19 +479,17 @@ class DropoutTest(tf.test.TestCase):
 class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):

   def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                   scale_after_normalization):
+                   scale_after_normalization, shift_after_normalization):
     y = (x - m) / np.sqrt(v + epsilon)
     y = y * gamma if scale_after_normalization else y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y

   def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                    scale_after_normalization):
+                    scale_after_normalization, shift_after_normalization):
     y = (x - m) * tf.rsqrt(v + epsilon)
     if scale_after_normalization:
       y = gamma * y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y

   def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                      scale_after_normalization):
@@ -510,10 +508,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         x, m, v, beta, gamma, epsilon, scale_after_normalization)

   def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
-                     scale_after_normalization):
+                     scale_after_normalization, shift_after_normalization):
     """New implementation."""
     return tf.nn.batch_normalization(
-        x, m, v, beta, gamma if scale_after_normalization else None, epsilon)
+        x, m, v, beta if shift_after_normalization else None,
+        gamma if scale_after_normalization else None, epsilon)

   def testBatchNorm(self):
     x_shape = [3, 5, 4, 2]
@@ -532,29 +531,35 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       gamma = tf.constant(gamma_val, name="gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn2 = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        bn1bw = self._tfBatchNormV1BW(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        bn1 = self._tfBatchNormV1(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        on = self._opsBatchNorm(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        np_bn = self._npBatchNorm(
-            x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-            scale_after_normalization)
-        tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
-            [bn2, bn1bw, bn1, on])
-        self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
-        self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
-        self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
-        self.assertAllClose(np_bn, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+        for shift_after_normalization in [True, False]:
+          bn2 = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          bn1bw = self._tfBatchNormV1BW(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization)
+          bn1 = self._tfBatchNormV1(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization)
+          on = self._opsBatchNorm(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          np_bn = self._npBatchNorm(
+              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+              scale_after_normalization, shift_after_normalization)
+          tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
+              [bn2, bn1bw, bn1, on])
+          self.assertAllClose(np_bn, ops_bn, atol=0.000001)
+          self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
+          self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+          # shift_after_normalization=False is not supported in v1.
+          if shift_after_normalization:
+            self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
+            self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
+            self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
+            self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)

   def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
-                             version, err_tolerance=1e-11):
+                             shift_after_normalization, version,
+                             err_tolerance=1e-11):
     x_shape = [3, 5, 4, 5]
     param_shape = [5]
     np.random.seed(1)  # Make it reproducible.
@@ -575,7 +580,8 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             x, m, v, beta, gamma, epsilon, scale_after_normalization)
       elif version == 2:
         output = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization,
+            shift_after_normalization)
       else:
         print("Invalid version", version)
         raise
@@ -583,31 +589,39 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
     err = tf.test.compute_gradient_error(
         all_params[param_index], all_shapes[param_index], output, x_shape)
-    print("Batch normalization v%d %s gradient %s scale err = " %
-          (version, tag, "with" if scale_after_normalization else "without"),
+    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
+          (version, tag, "with" if scale_after_normalization else "without",
+           "with" if shift_after_normalization else "without"),
           err)
     self.assertLess(err, err_tolerance)

-  def testBatchNormInputGradient(self):
+  def _testBatchNormGradientInAllNeedConfigs(
+      self, param_index, tag, err_tolerance=1e-11):
     for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(0, "x", scale_after_normalization, v)
+      for shift_after_normalization in [True, False]:
+        # shift_after_normalization=False is not supported in version 1.
+        for v in ([1, 2] if shift_after_normalization else [2]):
+          self._testBatchNormGradient(
+              param_index, tag, scale_after_normalization,
+              shift_after_normalization, v, err_tolerance)
+
+  def testBatchNormInputGradient(self):
+    self._testBatchNormGradientInAllNeedConfigs(0, "x")

   def testBatchNormMeanGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(1, "mean", scale_after_normalization, v)
+    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

   def testBatchNormVarianceGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(2, "variance", scale_after_normalization, v,
-                                    err_tolerance=1e-03)
+    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
+                                                err_tolerance=1e-03)

   def testBatchNormBetaGradient(self):
+    # Since beta does not exist when scale_after_normalization=False, we only
+    # test for scale_after_normalization=True.
     for scale_after_normalization in [True, False]:
       for v in [1, 2]:
-        self._testBatchNormGradient(3, "beta", scale_after_normalization, v)
+        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
+                                    v)

   def testBatchNormGammaGradient(self):
     # If scale_after_normalization is False, backprop for gamma in v1
@@ -615,8 +629,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     # gamma is not used at all, and the gradient is None, which displeases the
     # gradient checker.
     for scale_after_normalization in [True, False]:
-      self._testBatchNormGradient(4, "gamma", scale_after_normalization, 1)
-    self._testBatchNormGradient(4, "gamma", True, 2)
+      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
+                                  1)
+    for shift_after_normalization in [True, False]:
+      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
+                                  2)

   def testBatchNormGradImpl(self):
     x_shape = [7, 5, 4, 6]
@@ -644,7 +661,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             gen_nn_ops._batch_norm_with_global_normalization_grad(
                 x, m, v, gamma, backprop, epsilon, scale_after_normalization))
         on = self._opsBatchNorm(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
         odx, odm, odv, odb, odg = tf.gradients(
             [on], [x, m, v, beta, gamma], [backprop])
         if scale_after_normalization:
@@ -687,16 +704,20 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
           gamma, keep_dims_param_shape, name="keep_dims_gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        keep_dims_bn = self._tfBatchNormV2(
-            x, keep_dims_m, keep_dims_v, keep_dims_beta,
-            keep_dims_gamma, epsilon, scale_after_normalization)
-        tf_batch_norm, keep_dims_tf_batch_norm = sess.run([bn, keep_dims_bn])
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
-        self.assertAllClose(
-            tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          keep_dims_bn = self._tfBatchNormV2(
+              x, keep_dims_m, keep_dims_v, keep_dims_beta,
+              keep_dims_gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
+              [bn, keep_dims_bn])
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
+          self.assertAllClose(
+              tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

   def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.000001):
     x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -713,15 +734,17 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       gamma = tf.constant(gamma_val, name="gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        np_batch_norm = self._npBatchNorm(
-            x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-            scale_after_normalization)
-        [tf_batch_norm] = sess.run([bn])
-        self.assertEquals(x_shape, np_batch_norm.shape)
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          np_batch_norm = self._npBatchNorm(
+              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+              scale_after_normalization, shift_after_normalization)
+          [tf_batch_norm] = sess.run([bn])
+          self.assertEquals(x_shape, np_batch_norm.shape)
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

   def testBatchNormArbitraryShapes(self):
     """Test for a variety of shapes and moments.