Change inception v2 and v3 to use fused batchnorm
RELNOTES: Enable fused batchnorm on inception v2 and v3 PiperOrigin-RevId: 167904218
This commit is contained in:
parent
40cb77d26e
commit
0575c60ac8
@ -676,7 +676,8 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
|
||||
|
||||
def inception_v3_arg_scope(weight_decay=0.00004,
|
||||
stddev=0.1,
|
||||
batch_norm_var_collection='moving_vars'):
|
||||
batch_norm_var_collection='moving_vars',
|
||||
use_fused_batchnorm=True):
|
||||
"""Defines the default InceptionV3 arg scope.
|
||||
|
||||
Args:
|
||||
@ -684,6 +685,7 @@ def inception_v3_arg_scope(weight_decay=0.00004,
|
||||
stddev: The standard deviation of the trunctated normal weight initializer.
|
||||
batch_norm_var_collection: The name of the collection for the batch norm
|
||||
variables.
|
||||
use_fused_batchnorm: Enable fused batchnorm.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the inception v3 model.
|
||||
@ -695,6 +697,8 @@ def inception_v3_arg_scope(weight_decay=0.00004,
|
||||
'epsilon': 0.001,
|
||||
# collection containing update_ops.
|
||||
'updates_collections': ops.GraphKeys.UPDATE_OPS,
|
||||
# Use fused batch norm if possible.
|
||||
'fused': use_fused_batchnorm,
|
||||
# collection containing the moving mean and moving variance.
|
||||
'variables_collections': {
|
||||
'beta': None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user