diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 4e48f9c6fde..59fecea0c99 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -534,7 +534,7 @@ def train_step(sess, train_op, global_step, train_step_kwargs): if 'should_log' in train_step_kwargs: if sess.run(train_step_kwargs['should_log']): - logging.info('global step %d: loss = %.4f (%.2f sec)', + logging.info('global step %d: loss = %.4f (%.2f sec/step)', np_global_step, total_loss, time_elapsed) # TODO(nsilberman): figure out why we can't put this into sess.run. The diff --git a/tensorflow/contrib/slim/python/slim/nets/inception.py b/tensorflow/contrib/slim/python/slim/nets/inception.py index b6ec1b9da6c..6f50025644b 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception.py @@ -20,8 +20,10 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1 +from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_arg_scope from tensorflow.contrib.slim.python.slim.nets.inception_v1 import inception_v1_base from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2 +from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_arg_scope from tensorflow.contrib.slim.python.slim.nets.inception_v2 import inception_v2_base from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v1.py b/tensorflow/contrib/slim/python/slim/nets/inception_v1.py index 1e4ed212a86..8b9e3254a31 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v1.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception_v1.py @@ 
-299,3 +299,52 @@ def inception_v1(inputs, end_points['Predictions'] = prediction_fn(logits, scope='Predictions') return logits, end_points inception_v1.default_image_size = 224 + + +def inception_v1_arg_scope(weight_decay=0.00004, + use_batch_norm=True, + batch_norm_var_collection='moving_vars'): + """Defines the default InceptionV1 arg scope. + + Note: Although the original paper didn't use batch_norm, we found it useful. + + Args: + weight_decay: The weight decay to use for regularizing the model. + use_batch_norm: If `True`, batch_norm is applied after each convolution. + batch_norm_var_collection: The name of the collection for the batch norm + variables. + + Returns: + An `arg_scope` to use for the inception v1 model. + """ + batch_norm_params = { + # Decay for the moving averages. + 'decay': 0.9997, + # epsilon to prevent 0s in variance. + 'epsilon': 0.001, + # collection containing update_ops. + 'updates_collections': tf.GraphKeys.UPDATE_OPS, + # collection containing the moving mean and moving variance. + 'variables_collections': { + 'beta': None, + 'gamma': None, + 'moving_mean': [batch_norm_var_collection], + 'moving_variance': [batch_norm_var_collection], + } + } + if use_batch_norm: + normalizer_fn = slim.batch_norm + normalizer_params = batch_norm_params + else: + normalizer_fn = None + normalizer_params = {} + # Set weight_decay for weights in Conv and FC layers. 
+ with slim.arg_scope([slim.conv2d, slim.fully_connected], + weights_regularizer=slim.l2_regularizer(weight_decay)): + with slim.arg_scope( + [slim.conv2d], + weights_initializer=slim.variance_scaling_initializer(), + activation_fn=tf.nn.relu, + normalizer_fn=normalizer_fn, + normalizer_params=normalizer_params) as sc: + return sc diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py index 3be79f792e3..aeae7cffe69 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py @@ -110,8 +110,7 @@ class InceptionV1Test(tf.test.TestCase): batch_size = 5 height, width = 224, 224 inputs = tf.random_uniform((batch_size, height, width, 3)) - with slim.arg_scope([slim.conv2d, slim.separable_conv2d], - normalizer_fn=slim.batch_norm): + with slim.arg_scope(inception.inception_v1_arg_scope()): inception.inception_v1_base(inputs) total_params, _ = slim.model_analyzer.analyze_vars( slim.get_model_variables()) diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v2.py b/tensorflow/contrib/slim/python/slim/nets/inception_v2.py index cf9372df3ae..2e8ddbd133a 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v2.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception_v2.py @@ -513,3 +513,43 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): kernel_size_out = [min(shape[1], kernel_size[0]), min(shape[2], kernel_size[1])] return kernel_size_out + + +def inception_v2_arg_scope(weight_decay=0.00004, + batch_norm_var_collection='moving_vars'): + """Defines the default InceptionV2 arg scope. + + Args: + weight_decay: The weight decay to use for regularizing the model. + batch_norm_var_collection: The name of the collection for the batch norm + variables. + + Returns: + An `arg_scope` to use for the inception v2 model. 
+ """ + batch_norm_params = { + # Decay for the moving averages. + 'decay': 0.9997, + # epsilon to prevent 0s in variance. + 'epsilon': 0.001, + # collection containing update_ops. + 'updates_collections': tf.GraphKeys.UPDATE_OPS, + # collection containing the moving mean and moving variance. + 'variables_collections': { + 'beta': None, + 'gamma': None, + 'moving_mean': [batch_norm_var_collection], + 'moving_variance': [batch_norm_var_collection], + } + } + + # Set weight_decay for weights in Conv and FC layers. + with slim.arg_scope([slim.conv2d, slim.fully_connected], + weights_regularizer=slim.l2_regularizer(weight_decay)): + with slim.arg_scope( + [slim.conv2d], + weights_initializer=slim.variance_scaling_initializer(), + activation_fn=tf.nn.relu, + normalizer_fn=slim.batch_norm, + normalizer_params=batch_norm_params) as sc: + return sc diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py index 15ab355d64d..f9d26b9b4d3 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py @@ -107,12 +107,11 @@ class InceptionV2Test(tf.test.TestCase): batch_size = 5 height, width = 224, 224 inputs = tf.random_uniform((batch_size, height, width, 3)) - with slim.arg_scope([slim.conv2d, slim.separable_conv2d], - normalizer_fn=slim.batch_norm): + with slim.arg_scope(inception.inception_v2_arg_scope()): inception.inception_v2_base(inputs) total_params, _ = slim.model_analyzer.analyze_vars( slim.get_model_variables()) - self.assertAlmostEqual(10173240, total_params) + self.assertAlmostEqual(10173112, total_params) def testBuildEndPointsWithDepthMultiplierLessThanOne(self): batch_size = 5 diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v3.py b/tensorflow/contrib/slim/python/slim/nets/inception_v3.py index b938288ce46..0efd9013529 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v3.py 
+++ b/tensorflow/contrib/slim/python/slim/nets/inception_v3.py @@ -555,14 +555,12 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): return kernel_size_out -def inception_v3_arg_scope(is_training=True, - weight_decay=0.00004, +def inception_v3_arg_scope(weight_decay=0.00004, stddev=0.1, batch_norm_var_collection='moving_vars'): """Defines the default InceptionV3 arg scope. Args: - is_training: Whether or not we're training the model. weight_decay: The weight decay to use for regularizing the model. stddev: The standard deviation of the trunctated normal weight initializer. batch_norm_var_collection: The name of the collection for the batch norm @@ -572,11 +570,12 @@ def inception_v3_arg_scope(is_training=True, An `arg_scope` to use for the inception v3 model. """ batch_norm_params = { - 'is_training': is_training, # Decay for the moving averages. 'decay': 0.9997, # epsilon to prevent 0s in variance. 'epsilon': 0.001, + # collection containing update_ops. + 'updates_collections': tf.GraphKeys.UPDATE_OPS, # collection containing the moving mean and moving variance. 
'variables_collections': { 'beta': None, diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py index 786715f327e..ca978e30fec 100644 --- a/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py @@ -113,8 +113,7 @@ class InceptionV3Test(tf.test.TestCase): batch_size = 5 height, width = 299, 299 inputs = tf.random_uniform((batch_size, height, width, 3)) - with slim.arg_scope([slim.conv2d], - normalizer_fn=slim.batch_norm): + with slim.arg_scope(inception.inception_v3_arg_scope()): inception.inception_v3_base(inputs) total_params, _ = slim.model_analyzer.analyze_vars( slim.get_model_variables()) diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py index b375488cb04..de8c2effc21 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py @@ -206,7 +206,7 @@ def stack_blocks_dense(net, blocks, output_stride=None, return net -def resnet_arg_scope(is_training=False, +def resnet_arg_scope(is_training=True, weight_decay=0.0001, batch_norm_decay=0.997, batch_norm_epsilon=1e-5, @@ -236,7 +236,8 @@ def resnet_arg_scope(is_training=False, 'is_training': is_training, 'decay': batch_norm_decay, 'epsilon': batch_norm_epsilon, - 'scale': batch_norm_scale + 'scale': batch_norm_scale, + 'updates_collections': tf.GraphKeys.UPDATE_OPS, } with slim.arg_scope(