Add individual arg_scopes for each inception, and test them.

Change: 131226037
This commit is contained in:
Sergio Guadarrama 2016-08-24 14:52:55 -08:00 committed by TensorFlower Gardener
parent e207eff095
commit 399384b969
9 changed files with 102 additions and 14 deletions

View File

@ -534,7 +534,7 @@ def train_step(sess, train_op, global_step, train_step_kwargs):
if 'should_log' in train_step_kwargs: if 'should_log' in train_step_kwargs:
if sess.run(train_step_kwargs['should_log']): if sess.run(train_step_kwargs['should_log']):
logging.info('global step %d: loss = %.4f (%.2f sec)', logging.info('global step %d: loss = %.4f (%.2f sec/step)',
np_global_step, total_loss, time_elapsed) np_global_step, total_loss, time_elapsed)
# TODO(nsilberman): figure out why we can't put this into sess.run. The # TODO(nsilberman): figure out why we can't put this into sess.run. The

View File

@ -20,8 +20,10 @@ from __future__ import print_function
# pylint: disable=unused-import # pylint: disable=unused-import
from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1 from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1
from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_arg_scope
from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_base from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_base
from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2 from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2
from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_arg_scope
from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_base from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_base
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope

View File

@ -299,3 +299,52 @@ def inception_v1(inputs,
end_points['Predictions'] = prediction_fn(logits, scope='Predictions') end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
return logits, end_points return logits, end_points
inception_v1.default_image_size = 224 inception_v1.default_image_size = 224
def inception_v1_arg_scope(weight_decay=0.00004,
use_batch_norm=True,
batch_norm_var_collection='moving_vars'):
"""Defines the default InceptionV1 arg scope.
Note: Althougth the original paper didn't use batch_norm we found it useful.
Args:
weight_decay: The weight decay to use for regularizing the model.
use_batch_norm: "If `True`, batch_norm is applied after each convolution.
batch_norm_var_collection: The name of the collection for the batch norm
variables.
Returns:
An `arg_scope` to use for the inception v3 model.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': 0.9997,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
# collection containing the moving mean and moving variance.
'variables_collections': {
'beta': None,
'gamma': None,
'moving_mean': [batch_norm_var_collection],
'moving_variance': [batch_norm_var_collection],
}
}
if use_batch_norm:
normalizer_fn = slim.batch_norm
normalizer_params = batch_norm_params
else:
normalizer_fn = None
normalizer_params = {}
# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=normalizer_fn,
normalizer_params=normalizer_params) as sc:
return sc

View File

@ -110,8 +110,7 @@ class InceptionV1Test(tf.test.TestCase):
batch_size = 5 batch_size = 5
height, width = 224, 224 height, width = 224, 224
inputs = tf.random_uniform((batch_size, height, width, 3)) inputs = tf.random_uniform((batch_size, height, width, 3))
with slim.arg_scope([slim.conv2d, slim.separable_conv2d], with slim.arg_scope(inception.inception_v1_arg_scope()):
normalizer_fn=slim.batch_norm):
inception.inception_v1_base(inputs) inception.inception_v1_base(inputs)
total_params, _ = slim.model_analyzer.analyze_vars( total_params, _ = slim.model_analyzer.analyze_vars(
slim.get_model_variables()) slim.get_model_variables())

View File

@ -513,3 +513,43 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
kernel_size_out = [min(shape[1], kernel_size[0]), kernel_size_out = [min(shape[1], kernel_size[0]),
min(shape[2], kernel_size[1])] min(shape[2], kernel_size[1])]
return kernel_size_out return kernel_size_out
def inception_v2_arg_scope(weight_decay=0.00004,
batch_norm_var_collection='moving_vars'):
"""Defines the default InceptionV2 arg scope.
Args:
weight_decay: The weight decay to use for regularizing the model.
batch_norm_var_collection: The name of the collection for the batch norm
variables.
Returns:
An `arg_scope` to use for the inception v3 model.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': 0.9997,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
# collection containing the moving mean and moving variance.
'variables_collections': {
'beta': None,
'gamma': None,
'moving_mean': [batch_norm_var_collection],
'moving_variance': [batch_norm_var_collection],
}
}
# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params) as sc:
return sc

View File

@ -107,12 +107,11 @@ class InceptionV2Test(tf.test.TestCase):
batch_size = 5 batch_size = 5
height, width = 224, 224 height, width = 224, 224
inputs = tf.random_uniform((batch_size, height, width, 3)) inputs = tf.random_uniform((batch_size, height, width, 3))
with slim.arg_scope([slim.conv2d, slim.separable_conv2d], with slim.arg_scope(inception.inception_v2_arg_scope()):
normalizer_fn=slim.batch_norm):
inception.inception_v2_base(inputs) inception.inception_v2_base(inputs)
total_params, _ = slim.model_analyzer.analyze_vars( total_params, _ = slim.model_analyzer.analyze_vars(
slim.get_model_variables()) slim.get_model_variables())
self.assertAlmostEqual(10173240, total_params) self.assertAlmostEqual(10173112, total_params)
def testBuildEndPointsWithDepthMultiplierLessThanOne(self): def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
batch_size = 5 batch_size = 5

View File

@ -555,14 +555,12 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
return kernel_size_out return kernel_size_out
def inception_v3_arg_scope(is_training=True, def inception_v3_arg_scope(weight_decay=0.00004,
weight_decay=0.00004,
stddev=0.1, stddev=0.1,
batch_norm_var_collection='moving_vars'): batch_norm_var_collection='moving_vars'):
"""Defines the default InceptionV3 arg scope. """Defines the default InceptionV3 arg scope.
Args: Args:
is_training: Whether or not we're training the model.
weight_decay: The weight decay to use for regularizing the model. weight_decay: The weight decay to use for regularizing the model.
stddev: The standard deviation of the trunctated normal weight initializer. stddev: The standard deviation of the trunctated normal weight initializer.
batch_norm_var_collection: The name of the collection for the batch norm batch_norm_var_collection: The name of the collection for the batch norm
@ -572,11 +570,12 @@ def inception_v3_arg_scope(is_training=True,
An `arg_scope` to use for the inception v3 model. An `arg_scope` to use for the inception v3 model.
""" """
batch_norm_params = { batch_norm_params = {
'is_training': is_training,
# Decay for the moving averages. # Decay for the moving averages.
'decay': 0.9997, 'decay': 0.9997,
# epsilon to prevent 0s in variance. # epsilon to prevent 0s in variance.
'epsilon': 0.001, 'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
# collection containing the moving mean and moving variance. # collection containing the moving mean and moving variance.
'variables_collections': { 'variables_collections': {
'beta': None, 'beta': None,

View File

@ -113,8 +113,7 @@ class InceptionV3Test(tf.test.TestCase):
batch_size = 5 batch_size = 5
height, width = 299, 299 height, width = 299, 299
inputs = tf.random_uniform((batch_size, height, width, 3)) inputs = tf.random_uniform((batch_size, height, width, 3))
with slim.arg_scope([slim.conv2d], with slim.arg_scope(inception.inception_v3_arg_scope()):
normalizer_fn=slim.batch_norm):
inception.inception_v3_base(inputs) inception.inception_v3_base(inputs)
total_params, _ = slim.model_analyzer.analyze_vars( total_params, _ = slim.model_analyzer.analyze_vars(
slim.get_model_variables()) slim.get_model_variables())

View File

@ -206,7 +206,7 @@ def stack_blocks_dense(net, blocks, output_stride=None,
return net return net
def resnet_arg_scope(is_training=False, def resnet_arg_scope(is_training=True,
weight_decay=0.0001, weight_decay=0.0001,
batch_norm_decay=0.997, batch_norm_decay=0.997,
batch_norm_epsilon=1e-5, batch_norm_epsilon=1e-5,
@ -236,7 +236,8 @@ def resnet_arg_scope(is_training=False,
'is_training': is_training, 'is_training': is_training,
'decay': batch_norm_decay, 'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon, 'epsilon': batch_norm_epsilon,
'scale': batch_norm_scale 'scale': batch_norm_scale,
'updates_collections': tf.GraphKeys.UPDATE_OPS,
} }
with slim.arg_scope( with slim.arg_scope(