Enable fused batch norm, which is 15-20% faster for training and inference.

PiperOrigin-RevId: 168288154
This commit is contained in:
Yao Zhang 2017-09-11 14:41:31 -07:00 committed by TensorFlower Gardener
parent 08587d45b4
commit 3491881522
2 changed files with 2 additions and 13 deletions

View File

@@ -22,7 +22,6 @@ from __future__ import division
from __future__ import print_function
import functools
-import os
import six
from tensorflow.contrib.framework.python.ops import add_arg_scope
@@ -98,8 +97,6 @@ DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
DATA_FORMAT_NCDHW = 'NCDHW'
DATA_FORMAT_NDHWC = 'NDHWC'
-_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
-'').lower() in ('true', 't', '1')
@add_arg_scope
@@ -549,10 +546,8 @@ def batch_norm(inputs,
ValueError: If the rank of `inputs` is undefined.
ValueError: If rank or channels dimension of `inputs` is undefined.
"""
-# This environment variable is only used during the testing period of fused
-# batch norm and will be removed after that.
if fused is None:
-fused = _FUSED_DEFAULT
+fused = True
# Only use _fused_batch_norm if all of the following three
# conditions are true:

View File

@@ -20,7 +20,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
@@ -46,9 +45,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
-_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
-'').lower() in ('true', 't', '1')
class BatchNormalization(base.Layer):
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -140,10 +136,8 @@ class BatchNormalization(base.Layer):
self.beta_constraint = beta_constraint
self.gamma_constraint = gamma_constraint
self.renorm = renorm
-# This environment variable is only used during the testing period of fused
-# batch norm and will be removed after that.
if fused is None:
-fused = _FUSED_DEFAULT
+fused = True
self.fused = fused
self._bessels_correction_test_only = True