Enable fused batch norm, which is 15-20% faster for training and inference.
PiperOrigin-RevId: 168288154
This commit is contained in:
parent
08587d45b4
commit
3491881522
@ -22,7 +22,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import os
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
||||||
@ -98,8 +97,6 @@ DATA_FORMAT_NCHW = 'NCHW'
|
|||||||
DATA_FORMAT_NHWC = 'NHWC'
|
DATA_FORMAT_NHWC = 'NHWC'
|
||||||
DATA_FORMAT_NCDHW = 'NCDHW'
|
DATA_FORMAT_NCDHW = 'NCDHW'
|
||||||
DATA_FORMAT_NDHWC = 'NDHWC'
|
DATA_FORMAT_NDHWC = 'NDHWC'
|
||||||
_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
|
|
||||||
'').lower() in ('true', 't', '1')
|
|
||||||
|
|
||||||
|
|
||||||
@add_arg_scope
|
@add_arg_scope
|
||||||
@ -549,10 +546,8 @@ def batch_norm(inputs,
|
|||||||
ValueError: If the rank of `inputs` is undefined.
|
ValueError: If the rank of `inputs` is undefined.
|
||||||
ValueError: If rank or channels dimension of `inputs` is undefined.
|
ValueError: If rank or channels dimension of `inputs` is undefined.
|
||||||
"""
|
"""
|
||||||
# This environment variable is only used during the testing period of fused
|
|
||||||
# batch norm and will be removed after that.
|
|
||||||
if fused is None:
|
if fused is None:
|
||||||
fused = _FUSED_DEFAULT
|
fused = True
|
||||||
|
|
||||||
# Only use _fused_batch_norm if all of the following three
|
# Only use _fused_batch_norm if all of the following three
|
||||||
# conditions are true:
|
# conditions are true:
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
|
||||||
import six
|
import six
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -46,9 +45,6 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.layers import base
|
from tensorflow.python.layers import base
|
||||||
from tensorflow.python.layers import utils
|
from tensorflow.python.layers import utils
|
||||||
|
|
||||||
_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
|
|
||||||
'').lower() in ('true', 't', '1')
|
|
||||||
|
|
||||||
|
|
||||||
class BatchNormalization(base.Layer):
|
class BatchNormalization(base.Layer):
|
||||||
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
|
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
|
||||||
@ -140,10 +136,8 @@ class BatchNormalization(base.Layer):
|
|||||||
self.beta_constraint = beta_constraint
|
self.beta_constraint = beta_constraint
|
||||||
self.gamma_constraint = gamma_constraint
|
self.gamma_constraint = gamma_constraint
|
||||||
self.renorm = renorm
|
self.renorm = renorm
|
||||||
# This environment variable is only used during the testing period of fused
|
|
||||||
# batch norm and will be removed after that.
|
|
||||||
if fused is None:
|
if fused is None:
|
||||||
fused = _FUSED_DEFAULT
|
fused = True
|
||||||
|
|
||||||
self.fused = fused
|
self.fused = fused
|
||||||
self._bessels_correction_test_only = True
|
self._bessels_correction_test_only = True
|
||||||
|
Loading…
Reference in New Issue
Block a user