Support leaving the offset (beta) parameter out in batch_normalization, in which case no offset will be added after normalization.
Change: 115489328
parent 87a289103f
commit 4afef14f02
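With this change, callers can drop the additive term entirely by passing offset=None. A minimal usage sketch (assuming a TensorFlow build that includes this change; names, shapes, and values are illustrative only):

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.rand(4, 3).astype(np.float32))
    mean, variance = tf.nn.moments(x, axes=[0])
    scale = tf.constant([0.5, 1.0, 2.0])
    # offset=None skips the "+ beta" term: the result is
    # scale * (x - mean) * rsqrt(variance + variance_epsilon).
    y = tf.nn.batch_normalization(x, mean, variance, None, scale, 1e-3)
    with tf.Session() as sess:
      print(sess.run(y))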
@@ -593,8 +593,8 @@ def batch_normalization(x,
     x: Input `Tensor` of arbitrary dimensionality.
     mean: A mean `Tensor`.
     variance: A variance `Tensor`.
-    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, to
-      be applied to the normalized tensor.
+    offset: An offset `Tensor`, often denoted \\\\(\beta\\\\) in equations, or
+      None. If present, will be added to the normalized tensor.
     scale: A scale `Tensor`, often denoted \\\\(\gamma\\\\) in equations, or
       `None`. If present, the scale is applied to the normalized tensor.
     variance_epsilon: A small float number to avoid dividing by 0.
@@ -607,7 +607,7 @@ def batch_normalization(x,
     inv = math_ops.rsqrt(variance + variance_epsilon)
     if scale is not None:
       inv *= scale
-    return x * inv + (offset - mean * inv)
+    return x * inv + (offset - mean * inv if offset is not None else -mean * inv)


 def batch_norm_with_global_normalization(t,
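The rewritten return statement folds the optional scale into inv and groups the constant terms, so the offset-free case only drops the beta term. A NumPy sketch (illustrative, not part of the patch) showing the refactored expression matches the textbook form gamma * (x - mean) / sqrt(variance + epsilon) + beta:

    import numpy as np

    x = np.random.rand(4, 3)
    mean, var = x.mean(axis=0), x.var(axis=0)
    gamma, beta, eps = 1.5, 0.2, 1e-3

    inv = 1.0 / np.sqrt(var + eps)  # rsqrt(variance + variance_epsilon)
    inv *= gamma                    # scale folded into inv
    textbook = gamma * (x - mean) / np.sqrt(var + eps) + beta
    refactored = x * inv + (beta - mean * inv)
    assert np.allclose(textbook, refactored)
    # With offset=None, the parenthesized term reduces to -mean * inv.

The remaining hunks update the accompanying tests, threading a shift_after_normalization flag through the reference implementations and test drivers.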
@@ -479,19 +479,17 @@ class DropoutTest(tf.test.TestCase):
 class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):

   def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                   scale_after_normalization):
+                   scale_after_normalization, shift_after_normalization):
     y = (x - m) / np.sqrt(v + epsilon)
     y = y * gamma if scale_after_normalization else y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y

   def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
-                    scale_after_normalization):
+                    scale_after_normalization, shift_after_normalization):
     y = (x - m) * tf.rsqrt(v + epsilon)
     if scale_after_normalization:
       y = gamma * y
-    y += beta
-    return y
+    return y + beta if shift_after_normalization else y

   def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                      scale_after_normalization):
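Restated outside the test class, the NumPy reference makes the new flag's effect easy to verify: with the shift disabled, the output is independent of beta (a standalone sketch; the real helper is the method above):

    import numpy as np

    def np_batch_norm(x, m, v, beta, gamma, epsilon,
                      scale_after_normalization, shift_after_normalization):
      y = (x - m) / np.sqrt(v + epsilon)
      y = y * gamma if scale_after_normalization else y
      return y + beta if shift_after_normalization else y

    x = np.random.rand(2, 3)
    m, v = x.mean(axis=0), x.var(axis=0)
    a = np_batch_norm(x, m, v, 0.1, 2.0, 1e-3, True, False)
    b = np_batch_norm(x, m, v, 9.9, 2.0, 1e-3, True, False)
    assert np.allclose(a, b)  # beta is ignored when the shift is off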
@@ -510,10 +508,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         x, m, v, beta, gamma, epsilon, scale_after_normalization)

   def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
-                     scale_after_normalization):
+                     scale_after_normalization, shift_after_normalization):
     """New implementation."""
     return tf.nn.batch_normalization(
-        x, m, v, beta, gamma if scale_after_normalization else None, epsilon)
+        x, m, v, beta if shift_after_normalization else None,
+        gamma if scale_after_normalization else None, epsilon)

   def testBatchNorm(self):
     x_shape = [3, 5, 4, 2]
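For reference, the four flag combinations the wrapper can now express reduce to the following, with inv = 1 / sqrt(v + epsilon) (a summary derived from the reference implementations above, not text from the patch):

    scale=True,  shift=True   ->  gamma * (x - m) * inv + beta
    scale=True,  shift=False  ->  gamma * (x - m) * inv
    scale=False, shift=True   ->  (x - m) * inv + beta
    scale=False, shift=False  ->  (x - m) * inv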
@@ -532,29 +531,35 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
         gamma = tf.constant(gamma_val, name="gamma")
         epsilon = 0.001
         for scale_after_normalization in [True, False]:
-          bn2 = self._tfBatchNormV2(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          bn1bw = self._tfBatchNormV1BW(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          bn1 = self._tfBatchNormV1(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          on = self._opsBatchNorm(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          np_bn = self._npBatchNorm(
-              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-              scale_after_normalization)
-          tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
-              [bn2, bn1bw, bn1, on])
-          self.assertAllClose(np_bn, ops_bn, atol=0.000001)
-          self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
-          self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
-          self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
-          self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
-          self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
-          self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)
+          for shift_after_normalization in [True, False]:
+            bn2 = self._tfBatchNormV2(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            bn1bw = self._tfBatchNormV1BW(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            bn1 = self._tfBatchNormV1(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            on = self._opsBatchNorm(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            np_bn = self._npBatchNorm(
+                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+                scale_after_normalization, shift_after_normalization)
+            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
+                [bn2, bn1bw, bn1, on])
+            self.assertAllClose(np_bn, ops_bn, atol=0.000001)
+            self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
+            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
+            # shift_after_normalization=False is not supported in v1.
+            if shift_after_normalization:
+              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
+              self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
+              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
+              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)

   def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
-                             version, err_tolerance=1e-11):
+                             shift_after_normalization, version,
+                             err_tolerance=1e-11):
     x_shape = [3, 5, 4, 5]
     param_shape = [5]
     np.random.seed(1)  # Make it reproducible.
@@ -575,7 +580,8 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
             x, m, v, beta, gamma, epsilon, scale_after_normalization)
       elif version == 2:
         output = self._tfBatchNormV2(
-            x, m, v, beta, gamma, epsilon, scale_after_normalization)
+            x, m, v, beta, gamma, epsilon, scale_after_normalization,
+            shift_after_normalization)
       else:
         print("Invalid version", version)
         raise
@@ -583,31 +589,39 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
       all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
       err = tf.test.compute_gradient_error(
           all_params[param_index], all_shapes[param_index], output, x_shape)
-    print("Batch normalization v%d %s gradient %s scale err = " %
-          (version, tag, "with" if scale_after_normalization else "without"),
+    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
+          (version, tag, "with" if scale_after_normalization else "without",
+           "with" if shift_after_normalization else "without"),
          err)
    self.assertLess(err, err_tolerance)

-  def testBatchNormInputGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(0, "x", scale_after_normalization, v)
+  def _testBatchNormGradientInAllNeedConfigs(
+      self, param_index, tag, err_tolerance=1e-11):
+    for scale_after_normalization in [True, False]:
+      for shift_after_normalization in [True, False]:
+        # shift_after_normalization=False is not supported in version 1.
+        for v in ([1, 2] if shift_after_normalization else [2]):
+          self._testBatchNormGradient(
+              param_index, tag, scale_after_normalization,
+              shift_after_normalization, v, err_tolerance)
+
+  def testBatchNormInputGradient(self):
+    self._testBatchNormGradientInAllNeedConfigs(0, "x")

   def testBatchNormMeanGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(1, "mean", scale_after_normalization, v)
+    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

   def testBatchNormVarianceGradient(self):
-    for scale_after_normalization in [True, False]:
-      for v in [1, 2]:
-        self._testBatchNormGradient(2, "variance", scale_after_normalization, v,
-                                    err_tolerance=1e-03)
+    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
+                                                err_tolerance=1e-03)

   def testBatchNormBetaGradient(self):
+    # Since beta does not exist when shift_after_normalization=False, we only
+    # test for shift_after_normalization=True.
     for scale_after_normalization in [True, False]:
       for v in [1, 2]:
-        self._testBatchNormGradient(3, "beta", scale_after_normalization, v)
+        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
+                                    v)

   def testBatchNormGammaGradient(self):
     # If scale_after_normalization is False, backprop for gamma in v1
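The helper above delegates to tf.test.compute_gradient_error, which compares the analytic Jacobian against a numerically estimated one. A minimal sketch of that check (illustrative values; it passes None for both offset and scale, which the new code path allows):

    import tensorflow as tf

    with tf.Session():
      x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
      m = tf.constant([2.0, 3.0])
      v = tf.constant([1.0, 1.0])
      y = tf.nn.batch_normalization(x, m, v, None, None, 0.001)
      err = tf.test.compute_gradient_error(x, [2, 2], y, [2, 2])
      print(err)  # should be tiny; the tests assert err < err_tolerance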
@@ -615,8 +629,11 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     # gamma is not used at all, and the gradient is None, which displeases the
     # gradient checker.
     for scale_after_normalization in [True, False]:
-      self._testBatchNormGradient(4, "gamma", scale_after_normalization, 1)
-    self._testBatchNormGradient(4, "gamma", True, 2)
+      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
+                                  1)
+    for shift_after_normalization in [True, False]:
+      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
+                                  2)

   def testBatchNormGradImpl(self):
     x_shape = [7, 5, 4, 6]
@@ -644,7 +661,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
              gen_nn_ops._batch_norm_with_global_normalization_grad(
                  x, m, v, gamma, backprop, epsilon, scale_after_normalization))
          on = self._opsBatchNorm(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
+              x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
          odx, odm, odv, odb, odg = tf.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
@@ -687,12 +704,16 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
-          bn = self._tfBatchNormV2(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          keep_dims_bn = self._tfBatchNormV2(
-              x, keep_dims_m, keep_dims_v, keep_dims_beta,
-              keep_dims_gamma, epsilon, scale_after_normalization)
-          tf_batch_norm, keep_dims_tf_batch_norm = sess.run([bn, keep_dims_bn])
-          self.assertEquals(x_shape, tf_batch_norm.shape)
-          self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
-          self.assertAllClose(
+          for shift_after_normalization in [True, False]:
+            bn = self._tfBatchNormV2(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            keep_dims_bn = self._tfBatchNormV2(
+                x, keep_dims_m, keep_dims_v, keep_dims_beta,
+                keep_dims_gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
+                [bn, keep_dims_bn])
+            self.assertEquals(x_shape, tf_batch_norm.shape)
+            self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
+            self.assertAllClose(
@@ -713,11 +734,13 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
-          bn = self._tfBatchNormV2(
-              x, m, v, beta, gamma, epsilon, scale_after_normalization)
-          np_batch_norm = self._npBatchNorm(
-              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
-              scale_after_normalization)
-          [tf_batch_norm] = sess.run([bn])
-          self.assertEquals(x_shape, np_batch_norm.shape)
-          self.assertEquals(x_shape, tf_batch_norm.shape)
+          for shift_after_normalization in [True, False]:
+            bn = self._tfBatchNormV2(
+                x, m, v, beta, gamma, epsilon, scale_after_normalization,
+                shift_after_normalization)
+            np_batch_norm = self._npBatchNorm(
+                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+                scale_after_normalization, shift_after_normalization)
+            [tf_batch_norm] = sess.run([bn])
+            self.assertEquals(x_shape, np_batch_norm.shape)
+            self.assertEquals(x_shape, tf_batch_norm.shape)