Change inception v2 and v3 to use fused batchnorm

RELNOTES: Enable fused batchnorm on inception v2 and v3
PiperOrigin-RevId: 167904218
This commit is contained in:
Yunxing Dai 2017-09-07 13:38:57 -07:00 committed by TensorFlower Gardener
parent 40cb77d26e
commit 0575c60ac8

View File

@ -676,7 +676,8 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
def inception_v3_arg_scope(weight_decay=0.00004,
stddev=0.1,
batch_norm_var_collection='moving_vars'):
batch_norm_var_collection='moving_vars',
use_fused_batchnorm=True):
"""Defines the default InceptionV3 arg scope.
Args:
@ -684,6 +685,7 @@ def inception_v3_arg_scope(weight_decay=0.00004,
stddev: The standard deviation of the trunctated normal weight initializer.
batch_norm_var_collection: The name of the collection for the batch norm
variables.
use_fused_batchnorm: Enable fused batchnorm.
Returns:
An `arg_scope` to use for the inception v3 model.
@ -695,6 +697,8 @@ def inception_v3_arg_scope(weight_decay=0.00004,
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': ops.GraphKeys.UPDATE_OPS,
# Use fused batch norm if possible.
'fused': use_fused_batchnorm,
# collection containing the moving mean and moving variance.
'variables_collections': {
'beta': None,