Enable fused batch norm, which is 15-20% faster for training and inference.
PiperOrigin-RevId: 168288154
This commit is contained in:
parent
08587d45b4
commit
3491881522
@ -22,7 +22,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
||||
@ -98,8 +97,6 @@ DATA_FORMAT_NCHW = 'NCHW'
|
||||
DATA_FORMAT_NHWC = 'NHWC'
|
||||
DATA_FORMAT_NCDHW = 'NCDHW'
|
||||
DATA_FORMAT_NDHWC = 'NDHWC'
|
||||
_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
|
||||
'').lower() in ('true', 't', '1')
|
||||
|
||||
|
||||
@add_arg_scope
|
||||
@ -549,10 +546,8 @@ def batch_norm(inputs,
|
||||
ValueError: If the rank of `inputs` is undefined.
|
||||
ValueError: If rank or channels dimension of `inputs` is undefined.
|
||||
"""
|
||||
# This environment variable is only used during the testing period of fused
|
||||
# batch norm and will be removed after that.
|
||||
if fused is None:
|
||||
fused = _FUSED_DEFAULT
|
||||
fused = True
|
||||
|
||||
# Only use _fused_batch_norm if all of the following three
|
||||
# conditions are true:
|
||||
|
@ -20,7 +20,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import numpy as np
|
||||
@ -46,9 +45,6 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.layers import base
|
||||
from tensorflow.python.layers import utils
|
||||
|
||||
_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
|
||||
'').lower() in ('true', 't', '1')
|
||||
|
||||
|
||||
class BatchNormalization(base.Layer):
|
||||
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
|
||||
@ -140,10 +136,8 @@ class BatchNormalization(base.Layer):
|
||||
self.beta_constraint = beta_constraint
|
||||
self.gamma_constraint = gamma_constraint
|
||||
self.renorm = renorm
|
||||
# This environment variable is only used during the testing period of fused
|
||||
# batch norm and will be removed after that.
|
||||
if fused is None:
|
||||
fused = _FUSED_DEFAULT
|
||||
fused = True
|
||||
|
||||
self.fused = fused
|
||||
self._bessels_correction_test_only = True
|
||||
|
Loading…
Reference in New Issue
Block a user