diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
index f20c172ee37..4a905b1b2a0 100644
--- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
@@ -78,10 +78,11 @@ class DropStaleGradientOptimizer(optimizer.Optimizer):
   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     gradients = []
     # Number of stale gradients.
-    stale_counter = variable_scope.get_variable(
-        "stale_counter", [],
-        initializer=init_ops.zeros_initializer(),
-        trainable=False)
+    with ops.colocate_with(global_step):
+      stale_counter = variable_scope.get_variable(
+          "stale_counter", [],
+          initializer=init_ops.zeros_initializer(),
+          trainable=False)

     def _AcceptGradientOp():
       with ops.control_dependencies(
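
For context, a minimal standalone sketch of the pattern this patch applies: wrapping variable creation in `ops.colocate_with(global_step)` so the new variable is placed on the same device as `global_step` (typically the parameter server in a distributed setup). The surrounding setup (creating a global step with `training_util`) is an assumption for illustration and is not part of the patch.

```python
# Sketch only, assuming TensorFlow 1.x graph-mode internal APIs,
# mirroring the imports already used by drop_stale_gradient_optimizer.py.
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training_util

# Hypothetical global step, standing in for the `global_step` argument
# passed to apply_gradients() in the real optimizer.
global_step = training_util.get_or_create_global_step()

# Ops and variables created inside this block are colocated with
# `global_step`, which is what the patch does for `stale_counter`.
with ops.colocate_with(global_step):
  stale_counter = variable_scope.get_variable(
      "stale_counter", [],
      initializer=init_ops.zeros_initializer(),
      trainable=False)
```

Without the colocation block, `stale_counter` is placed by the default device policy, which may differ from where `global_step` lives; colocating the two keeps the staleness check and the step counter on one device.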