commit 399384b969
parent e207eff095

Add individual arg_scopes for each inception, and test them.

Change: 131226037
@@ -534,7 +534,7 @@ def train_step(sess, train_op, global_step, train_step_kwargs):
   if 'should_log' in train_step_kwargs:
     if sess.run(train_step_kwargs['should_log']):
-      logging.info('global step %d: loss = %.4f (%.2f sec)',
+      logging.info('global step %d: loss = %.4f (%.2f sec/step)',
                    np_global_step, total_loss, time_elapsed)

   # TODO(nsilberman): figure out why we can't put this into sess.run. The
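This hunk only fixes the unit in the log line of slim's training loop (sec becomes sec/step). For context, a minimal sketch of how that logging is driven, assuming the TF 1.x contrib-era API; the toy model, logdir, and step counts are illustrative and not from this commit:

    import tensorflow as tf
    slim = tf.contrib.slim

    # Toy regression graph, just to have a loss to train on.
    inputs = tf.random_uniform((4, 8))
    targets = tf.random_uniform((4, 1))
    predictions = slim.fully_connected(inputs, 1, activation_fn=None)
    total_loss = tf.reduce_mean(tf.square(predictions - targets))

    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)
    # log_every_n_steps feeds the `should_log` flag tested in the hunk above.
    slim.learning.train(train_op, logdir='/tmp/toy_logs', number_of_steps=20,
                        log_every_n_steps=5)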
@@ -20,8 +20,10 @@ from __future__ import print_function

 # pylint: disable=unused-import
 from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1
+from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_arg_scope
 from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_base
 from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2
+from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_arg_scope
 from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_base
 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3
 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope
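With these two re-exports, every per-version scope is reachable from the aggregating inception module alongside the nets themselves. A hedged sketch, assuming the file modified here is the aggregator at the contrib path used in the imports above:

    # Sketch only: the import path mirrors the ones shown in the hunk.
    from tensorflow.contrib.slim.python.slim.nets import inception

    v1_scope = inception.inception_v1_arg_scope()
    v2_scope = inception.inception_v2_arg_scope()
    v3_scope = inception.inception_v3_arg_scope()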
@@ -299,3 +299,52 @@ def inception_v1(inputs,
   end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
   return logits, end_points
 inception_v1.default_image_size = 224
+
+
+def inception_v1_arg_scope(weight_decay=0.00004,
+                           use_batch_norm=True,
+                           batch_norm_var_collection='moving_vars'):
+  """Defines the default InceptionV1 arg scope.
+
+  Note: Although the original paper didn't use batch_norm we found it useful.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    use_batch_norm: If `True`, batch_norm is applied after each convolution.
+    batch_norm_var_collection: The name of the collection for the batch norm
+      variables.
+
+  Returns:
+    An `arg_scope` to use for the inception v1 model.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': 0.9997,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+      # collection containing the moving mean and moving variance.
+      'variables_collections': {
+          'beta': None,
+          'gamma': None,
+          'moving_mean': [batch_norm_var_collection],
+          'moving_variance': [batch_norm_var_collection],
+      }
+  }
+  if use_batch_norm:
+    normalizer_fn = slim.batch_norm
+    normalizer_params = batch_norm_params
+  else:
+    normalizer_fn = None
+    normalizer_params = {}
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=normalizer_fn,
+        normalizer_params=normalizer_params) as sc:
+      return sc
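Typical use wraps model construction in the scope so weight decay, the variance-scaling initializer, ReLU, and (by default) batch norm all apply without per-layer plumbing. A hedged sketch; batch size and num_classes are illustrative, not from the commit:

    import tensorflow as tf
    from tensorflow.contrib.slim.python.slim.nets import inception_v1

    slim = tf.contrib.slim

    images = tf.random_uniform((8, 224, 224, 3))  # inception_v1.default_image_size
    with slim.arg_scope(inception_v1.inception_v1_arg_scope()):
      logits, end_points = inception_v1.inception_v1(images, num_classes=1000)

Passing use_batch_norm=False to the scope drops the normalizer entirely, in which case the convolutions fall back to plain biases.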
@@ -110,8 +110,7 @@ class InceptionV1Test(tf.test.TestCase):
     batch_size = 5
     height, width = 224, 224
     inputs = tf.random_uniform((batch_size, height, width, 3))
-    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
-                        normalizer_fn=slim.batch_norm):
+    with slim.arg_scope(inception.inception_v1_arg_scope()):
       inception.inception_v1_base(inputs)
     total_params, _ = slim.model_analyzer.analyze_vars(
         slim.get_model_variables())
@@ -513,3 +513,43 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
   kernel_size_out = [min(shape[1], kernel_size[0]),
                      min(shape[2], kernel_size[1])]
   return kernel_size_out
+
+
+def inception_v2_arg_scope(weight_decay=0.00004,
+                           batch_norm_var_collection='moving_vars'):
+  """Defines the default InceptionV2 arg scope.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    batch_norm_var_collection: The name of the collection for the batch norm
+      variables.
+
+  Returns:
+    An `arg_scope` to use for the inception v2 model.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': 0.9997,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+      # collection containing the moving mean and moving variance.
+      'variables_collections': {
+          'beta': None,
+          'gamma': None,
+          'moving_mean': [batch_norm_var_collection],
+          'moving_variance': [batch_norm_var_collection],
+      }
+  }
+
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=slim.batch_norm,
+        normalizer_params=batch_norm_params) as sc:
+      return sc
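Unlike the v1 scope, inception_v2_arg_scope always applies batch norm (there is no use_batch_norm switch), and like the others it carries no is_training flag: training mode is decided where the network is built. A hedged sketch, assuming the contrib-era inception_v2 signature with its own is_training argument:

    import tensorflow as tf
    from tensorflow.contrib.slim.python.slim.nets import inception_v2

    slim = tf.contrib.slim

    images = tf.random_uniform((8, 224, 224, 3))
    with slim.arg_scope(inception_v2.inception_v2_arg_scope()):
      # The network forwards is_training to batch_norm/dropout internally.
      logits, end_points = inception_v2.inception_v2(images, num_classes=1000,
                                                     is_training=True)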
@@ -107,12 +107,11 @@ class InceptionV2Test(tf.test.TestCase):
     batch_size = 5
     height, width = 224, 224
     inputs = tf.random_uniform((batch_size, height, width, 3))
-    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
-                        normalizer_fn=slim.batch_norm):
+    with slim.arg_scope(inception.inception_v2_arg_scope()):
       inception.inception_v2_base(inputs)
     total_params, _ = slim.model_analyzer.analyze_vars(
         slim.get_model_variables())
-    self.assertAlmostEqual(10173240, total_params)
+    self.assertAlmostEqual(10173112, total_params)

   def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
     batch_size = 5
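The expected parameter count drops by 128. The commit doesn't explain why, but one plausible accounting, offered as an inference only: the old test scope also batch-normalized slim.separable_conv2d, while inception_v2_arg_scope normalizes slim.conv2d alone, so the 64-channel separable stem conv loses its batch-norm variables and regains plain biases:

    # Inference, not stated in the commit.
    bn_vars = 3 * 64        # beta + moving_mean + moving_variance
    bias_vars = 64          # biases reappear once the normalizer is gone
    print(bn_vars - bias_vars)  # 128 == 10173240 - 10173112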
@@ -555,14 +555,12 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
   return kernel_size_out


-def inception_v3_arg_scope(is_training=True,
-                           weight_decay=0.00004,
+def inception_v3_arg_scope(weight_decay=0.00004,
                            stddev=0.1,
                            batch_norm_var_collection='moving_vars'):
   """Defines the default InceptionV3 arg scope.

   Args:
-    is_training: Whether or not we're training the model.
     weight_decay: The weight decay to use for regularizing the model.
     stddev: The standard deviation of the truncated normal weight initializer.
     batch_norm_var_collection: The name of the collection for the batch norm
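With is_training removed from the scope, the scopes become pure architectural defaults and the training/eval distinction moves to the network call. A hedged sketch of the new calling convention; num_classes and the batch shape are illustrative:

    import tensorflow as tf
    from tensorflow.contrib.slim.python.slim.nets import inception_v3

    slim = tf.contrib.slim

    images = tf.random_uniform((8, 299, 299, 3))
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
      # is_training now lives on the network call, not on the scope.
      logits, end_points = inception_v3.inception_v3(images, num_classes=1001,
                                                     is_training=False)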
@@ -572,11 +570,12 @@ def inception_v3_arg_scope(is_training=True,
     An `arg_scope` to use for the inception v3 model.
   """
   batch_norm_params = {
-      'is_training': is_training,
       # Decay for the moving averages.
       'decay': 0.9997,
       # epsilon to prevent 0s in variance.
       'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
       # collection containing the moving mean and moving variance.
       'variables_collections': {
           'beta': None,
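The 'updates_collections': tf.GraphKeys.UPDATE_OPS entry that this commit adds to every scope defers the batch-norm moving-average updates into the UPDATE_OPS collection, so training code has to run them alongside each step. A minimal sketch of the standard pattern, not taken from this commit:

    import tensorflow as tf
    slim = tf.contrib.slim

    inputs = tf.random_uniform((4, 8, 8, 3))
    net = slim.conv2d(inputs, 16, [3, 3], normalizer_fn=slim.batch_norm)
    total_loss = tf.reduce_mean(net)

    # batch_norm registered its moving-average updates in UPDATE_OPS;
    # the control dependency runs them on every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = tf.train.GradientDescentOptimizer(0.01).minimize(total_loss)

slim's create_train_op can also collect UPDATE_OPS for you, so code built on slim.learning usually gets this dependency without the explicit block.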
@@ -113,8 +113,7 @@ class InceptionV3Test(tf.test.TestCase):
     batch_size = 5
     height, width = 299, 299
     inputs = tf.random_uniform((batch_size, height, width, 3))
-    with slim.arg_scope([slim.conv2d],
-                        normalizer_fn=slim.batch_norm):
+    with slim.arg_scope(inception.inception_v3_arg_scope()):
       inception.inception_v3_base(inputs)
     total_params, _ = slim.model_analyzer.analyze_vars(
         slim.get_model_variables())
@@ -206,7 +206,7 @@ def stack_blocks_dense(net, blocks, output_stride=None,
   return net


-def resnet_arg_scope(is_training=False,
+def resnet_arg_scope(is_training=True,
                      weight_decay=0.0001,
                      batch_norm_decay=0.997,
                      batch_norm_epsilon=1e-5,
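Because the resnet default flips to is_training=True here, evaluation code must opt out explicitly or the batch-norm statistics keep updating at eval time. A hedged sketch; it assumes resnet_arg_scope is re-exported by the resnet_v1 module as in the slim models layout, otherwise import it from the resnet_utils module this hunk modifies:

    import tensorflow as tf
    from tensorflow.contrib.slim.python.slim.nets import resnet_v1

    slim = tf.contrib.slim

    images = tf.random_uniform((1, 224, 224, 3))
    with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
      net, end_points = resnet_v1.resnet_v1_50(images, num_classes=1000)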
@@ -236,7 +236,8 @@ def resnet_arg_scope(is_training=False,
       'is_training': is_training,
       'decay': batch_norm_decay,
       'epsilon': batch_norm_epsilon,
-      'scale': batch_norm_scale
+      'scale': batch_norm_scale,
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
   }

   with slim.arg_scope(