From 3491881522a4eafafea8acf8113f99468ead735c Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Mon, 11 Sep 2017 14:41:31 -0700 Subject: [PATCH] Enable fused batch norm, which is 15-20% faster for training and inference. PiperOrigin-RevId: 168288154 --- tensorflow/contrib/layers/python/layers/layers.py | 7 +------ tensorflow/python/layers/normalization.py | 8 +------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 36421f86587..33c31262664 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function import functools -import os import six from tensorflow.contrib.framework.python.ops import add_arg_scope @@ -98,8 +97,6 @@ DATA_FORMAT_NCHW = 'NCHW' DATA_FORMAT_NHWC = 'NHWC' DATA_FORMAT_NCDHW = 'NCDHW' DATA_FORMAT_NDHWC = 'NDHWC' -_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM', - '').lower() in ('true', 't', '1') @add_arg_scope @@ -549,10 +546,8 @@ def batch_norm(inputs, ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. """ - # This environment variable is only used during the testing period of fused - # batch norm and will be removed after that. if fused is None: - fused = _FUSED_DEFAULT + fused = True # Only use _fused_batch_norm if all of the following three # conditions are true: diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 62265dce3c5..222817cd3a3 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -20,7 +20,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import six from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np @@ -46,9 +45,6 @@ from tensorflow.python.ops import variables from tensorflow.python.layers import base from tensorflow.python.layers import utils -_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM', - '').lower() in ('true', 't', '1') - class BatchNormalization(base.Layer): """Batch Normalization layer from http://arxiv.org/abs/1502.03167. @@ -140,10 +136,8 @@ class BatchNormalization(base.Layer): self.beta_constraint = beta_constraint self.gamma_constraint = gamma_constraint self.renorm = renorm - # This environment variable is only used during the testing period of fused - # batch norm and will be removed after that. if fused is None: - fused = _FUSED_DEFAULT + fused = True self.fused = fused self._bessels_correction_test_only = True