diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 3f4937b1e98..ecff8241a63 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -593,8 +593,8 @@ def batch_normalization(x,
     x: Input `Tensor` of arbitrary dimensionality.
     mean: A mean `Tensor`.
     variance: A variance `Tensor`.
-    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, to
-      be applied to the normalized tensor.
+    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, or
+      `None`. If present, will be added to the normalized tensor.
     scale: A scale `Tensor`, often denoted \\\\(\gamma\\\\) in equations, or
       `None`. If present, the scale is applied to the normalized tensor.
     variance_epsilon: A small float number to avoid dividing by 0.
@@ -607,7 +607,8 @@ def batch_normalization(x,
     inv = math_ops.rsqrt(variance + variance_epsilon)
     if scale is not None:
       inv *= scale
-    return x * inv + (offset - mean * inv)
+    return x * inv + (
+        offset - mean * inv if offset is not None else -mean * inv)
 
 
 def batch_norm_with_global_normalization(t,
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index b908f9b8bd7..30c866e6a4c 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -479,19 +479,17 @@ class DropoutTest(tf.test.TestCase):
 class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
 
   def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                   scale_after_normalization):
+                   scale_after_normalization, shift_after_normalization):
     y = (x - m) / np.sqrt(v + epsilon)
     y = y * gamma if scale_after_normalization else y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y
 
   def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                    scale_after_normalization):
+                    scale_after_normalization, shift_after_normalization):
     y = (x - m) * tf.rsqrt(v + epsilon)
     if scale_after_normalization:
       y = gamma * y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y
 
   def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                      scale_after_normalization):
@@ -510,10 +508,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         x, m, v, beta, gamma, epsilon, scale_after_normalization)
 
   def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
-                     scale_after_normalization):
+                     scale_after_normalization, shift_after_normalization):
     """New implementation."""
     return tf.nn.batch_normalization(
-        x, m, v, beta, gamma if scale_after_normalization else None, epsilon)
+        x, m, v, beta if shift_after_normalization else None,
+        gamma if scale_after_normalization else None, epsilon)
 
   def testBatchNorm(self):
     x_shape = [3, 5, 4, 2]
@@ -532,29 +531,35 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       gamma = tf.constant(gamma_val, name="gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn2 = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        bn1bw = self._tfBatchNormV1BW(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        bn1 = self._tfBatchNormV1(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        on = self._opsBatchNorm(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        np_bn = self._npBatchNorm(
-            x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-            scale_after_normalization)
-        tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
-            [bn2, bn1bw, bn1, on])
-        self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
-        self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
-        self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
-        self.assertAllClose(np_bn, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
-        self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+        for shift_after_normalization in [True, False]:
+          bn2 = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          bn1bw = self._tfBatchNormV1BW(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization)
+          bn1 = self._tfBatchNormV1(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization)
+          on = self._opsBatchNorm(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          np_bn = self._npBatchNorm(
+              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+              scale_after_normalization, shift_after_normalization)
+          tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
+              [bn2, bn1bw, bn1, on])
+          self.assertAllClose(np_bn, ops_bn, atol=0.000001)
+          self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
+          self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+          # shift_after_normalization=False is not supported in v1.
+          if shift_after_normalization:
+            self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
+            self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
+            self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
+            self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
 
   def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
-                             version, err_tolerance=1e-11):
+                             shift_after_normalization, version,
+                             err_tolerance=1e-11):
     x_shape = [3, 5, 4, 5]
     param_shape = [5]
     np.random.seed(1)  # Make it reproducible.
@@ -575,7 +580,8 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             x, m, v, beta, gamma, epsilon, scale_after_normalization)
       elif version == 2:
         output = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization,
+            shift_after_normalization)
       else:
         print("Invalid version", version)
         raise
@@ -583,31 +589,39 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
       err = tf.test.compute_gradient_error(
           all_params[param_index], all_shapes[param_index], output, x_shape)
-    print("Batch normalization v%d %s gradient %s scale err = " %
-          (version, tag, "with" if scale_after_normalization else "without"),
+    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
+          (version, tag, "with" if scale_after_normalization else "without",
+           "with" if shift_after_normalization else "without"),
           err)
     self.assertLess(err, err_tolerance)
 
-  def testBatchNormInputGradient(self):
+  def _testBatchNormGradientInAllNeedConfigs(
+      self, param_index, tag, err_tolerance=1e-11):
     for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(0, "x", scale_after_normalization, v)
+      for shift_after_normalization in [True, False]:
+        # shift_after_normalization=False is not supported in version 1.
+        for v in ([1, 2] if shift_after_normalization else [2]):
+          self._testBatchNormGradient(
+              param_index, tag, scale_after_normalization,
+              shift_after_normalization, v, err_tolerance)
+
+  def testBatchNormInputGradient(self):
+    self._testBatchNormGradientInAllNeedConfigs(0, "x")
 
   def testBatchNormMeanGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(1, "mean", scale_after_normalization, v)
+    self._testBatchNormGradientInAllNeedConfigs(1, "mean")
 
   def testBatchNormVarianceGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(2, "variance", scale_after_normalization, v,
-                                    err_tolerance=1e-03)
+    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
+                                                err_tolerance=1e-03)
 
   def testBatchNormBetaGradient(self):
+    # Since beta is not used when shift_after_normalization=False, we only
+    # test for shift_after_normalization=True.
     for scale_after_normalization in [True, False]:
       for v in [1, 2]:
-        self._testBatchNormGradient(3, "beta", scale_after_normalization, v)
+        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
+                                    v)
 
   def testBatchNormGammaGradient(self):
     # If scale_after_normalization is False, backprop for gamma in v1
@@ -615,8 +629,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     # gamma is not used at all, and the gradient is None, which displeases the
     # gradient checker.
     for scale_after_normalization in [True, False]:
-      self._testBatchNormGradient(4, "gamma", scale_after_normalization, 1)
-      self._testBatchNormGradient(4, "gamma", True, 2)
+      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
+                                  1)
+    for shift_after_normalization in [True, False]:
+      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
+                                  2)
 
   def testBatchNormGradImpl(self):
     x_shape = [7, 5, 4, 6]
@@ -644,7 +661,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             gen_nn_ops._batch_norm_with_global_normalization_grad(
                 x, m, v, gamma, backprop, epsilon, scale_after_normalization))
         on = self._opsBatchNorm(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
         odx, odm, odv, odb, odg = tf.gradients(
             [on], [x, m, v, beta, gamma], [backprop])
         if scale_after_normalization:
@@ -687,16 +704,20 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
           gamma, keep_dims_param_shape, name="keep_dims_gamma")
       epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        keep_dims_bn = self._tfBatchNormV2(
-            x, keep_dims_m, keep_dims_v, keep_dims_beta,
-            keep_dims_gamma, epsilon, scale_after_normalization)
-        tf_batch_norm, keep_dims_tf_batch_norm = sess.run([bn, keep_dims_bn])
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
-        self.assertAllClose(
-            tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          keep_dims_bn = self._tfBatchNormV2(
+              x, keep_dims_m, keep_dims_v, keep_dims_beta,
+              keep_dims_gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
+              [bn, keep_dims_bn])
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
+          self.assertAllClose(
+              tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
 
   def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.000001):
     x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -713,15 +734,17 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       gamma = tf.constant(gamma_val, name="gamma")
      epsilon = 0.001
       for scale_after_normalization in [True, False]:
-        bn = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
-        np_batch_norm = self._npBatchNorm(
-            x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-            scale_after_normalization)
-        [tf_batch_norm] = sess.run([bn])
-        self.assertEquals(x_shape, np_batch_norm.shape)
-        self.assertEquals(x_shape, tf_batch_norm.shape)
-        self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
+        for shift_after_normalization in [True, False]:
+          bn = self._tfBatchNormV2(
+              x, m, v, beta, gamma, epsilon, scale_after_normalization,
+              shift_after_normalization)
+          np_batch_norm = self._npBatchNorm(
+              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+              scale_after_normalization, shift_after_normalization)
+          [tf_batch_norm] = sess.run([bn])
+          self.assertEquals(x_shape, np_batch_norm.shape)
+          self.assertEquals(x_shape, tf_batch_norm.shape)
+          self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
 
   def testBatchNormArbitraryShapes(self):
     """Test for a variety of shapes and moments.