From dadf364f8592860dc42c89f9b307b0ab5f011319 Mon Sep 17 00:00:00 2001
From: Ziming Dong
Date: Wed, 20 Jul 2016 03:12:38 +0800
Subject: [PATCH] add support for nesterov momentum (#2798)

* add support for dense nesterov momentum
* add sparse version and docs
* clean work
* cleanups
* modify gpu functor
* clean work
---
 tensorflow/core/kernels/training_ops.cc     | 28 +++++---
 tensorflow/core/kernels/training_ops.h      |  4 +-
 .../core/kernels/training_ops_gpu.cu.cc     | 12 +++-
 tensorflow/core/ops/training_ops.cc         | 12 +++-
 tensorflow/python/training/momentum.py      |  9 ++-
 tensorflow/python/training/momentum_test.py | 69 +++++++++++++++++++
 6 files changed, 117 insertions(+), 17 deletions(-)

diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 2f9714a37af..867045eb1f0 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -58,9 +58,8 @@ struct ApplyAdadelta {
       typename TTypes<T>::ConstFlat grad) {
     accum.device(d) =
         accum * rho() + grad.square() * (static_cast<T>(1) - rho());
-    const auto update =
-        (accum_update + epsilon()).sqrt() *
-        (accum + epsilon()).rsqrt() * grad;
+    const auto update =
+        (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
     accum_update.device(d) =
         accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
     var.device(d) -= update * lr();
@@ -176,9 +175,13 @@ struct ApplyMomentum {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum) {
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
     accum.device(d) = accum * momentum() + grad;
-    var.device(d) -= accum * lr();
+    if (use_nesterov) {
+      var.device(d) -= grad * lr() + accum * momentum() * lr();
+    } else {
+      var.device(d) -= accum * lr();
+    }
   }
 };
 
@@ -1515,6 +1518,7 @@ class ApplyMomentumOp : public OpKernel {
  public:
   explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -1554,12 +1558,13 @@
     const Device& device = ctx->template eigen_device<Device>();
     functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                         lr.scalar<T>(), grad.flat<T>(),
-                                        momentum.scalar<T>());
+                                        momentum.scalar<T>(), use_nesterov_);
     ctx->forward_ref_input_to_ref_output(0, 0);
   }
 
  private:
  bool use_exclusive_lock_;
+  bool use_nesterov_;
 };
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -1584,7 +1589,7 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T>::Flat var,                   \
       typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
       typename TTypes<T>::ConstFlat grad,                                 \
-      typename TTypes<T>::ConstScalar momentum);                          \
+      typename TTypes<T>::ConstScalar momentum, bool use_nesterov);       \
   extern template struct ApplyMomentum<GPUDevice, T>;
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
@@ -1605,6 +1610,7 @@ class SparseApplyMomentumOp : public OpKernel {
  public:
   explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   }
 
   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
@@ -1672,7 +1678,12 @@
         auto g = grad_flat.template chip<0>(i);
         auto v = var_flat.template chip<0>(index);
         a = a * a.constant(momentum_scalar) + g;
-        v -= a.constant(lr_scalar) * a;
+        if (use_nesterov_) {
+          v -= g.constant(lr_scalar) * g +
+               a.constant(lr_scalar) * a.constant(momentum_scalar) * a;
+        } else {
+          v -= a.constant(lr_scalar) * a;
+        }
       }
     }
 
@@ -1681,6 +1692,7 @@
 
  private:
  bool use_exclusive_lock_;
+  bool use_nesterov_;
 };
 
 #define REGISTER_KERNELS(T, Tindices)                                \
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index b9946cd9228..017cae6c7c0 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
 #define TENSORFLOW_KERNELS_TRAINING_OPS_H_
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 namespace functor {
@@ -98,7 +98,7 @@ struct ApplyMomentum {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum);
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
 };
 
 template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index ab56880cfb5..589e70e76d1 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
-#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/training_ops.h"
+#include "tensorflow/core/framework/register_types.h"
 
 namespace tensorflow {
 
@@ -84,12 +84,18 @@ struct ApplyMomentum<GPUDevice, T> {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum) {
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
-    var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+    if (use_nesterov) {
+      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
+                       accum * momentum.reshape(single).broadcast(bcast) *
+                           lr.reshape(single).broadcast(bcast);
+    } else {
+      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+    }
   }
 };
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 756a7707f27..4c2b67ea95d 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -341,8 +341,10 @@ REGISTER_OP("ApplyMomentum")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
+    .Attr("use_nesterov: bool = false")
     .Doc(R"doc(
-Update '*var' according to the momentum scheme.
+Update '*var' according to the momentum scheme. Set use_nesterov = True if you
+want to use Nesterov momentum.
 
 accum = accum * momentum + grad
 var -= lr * accum
@@ -356,6 +358,9 @@ out: Same as "var".
 use_locking: If `True`, updating of the var and accum tensors will be protected
   by a lock; otherwise the behavior is undefined, but may exhibit less
   contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+  var - lr * momentum * accum, so in the end, the var you get is actually
+  var - lr * momentum * accum.
)doc"); REGISTER_OP("SparseApplyMomentum") @@ -369,8 +374,10 @@ REGISTER_OP("SparseApplyMomentum") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("use_nesterov: bool = false") .Doc(R"doc( Update relevant entries in '*var' and '*accum' according to the momentum scheme. +Set use_nesterov = True if you want to use Nesterov momentum. That is for rows we have grad for, we update var and accum as follows: @@ -387,6 +394,9 @@ out: Same as "var". use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +use_nesterov: If `True`, the tensor passed to compute grad will be +var - lr * momentum * accum, so in the end, the var you get is actually +var - lr * momentum * accum. )doc"); REGISTER_OP("ApplyAdam") diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index 1586ddfdec2..62f8028ce68 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -31,7 +31,7 @@ class MomentumOptimizer(optimizer.Optimizer): """ def __init__(self, learning_rate, momentum, - use_locking=False, name="Momentum"): + use_locking=False, name="Momentum", use_nesterov=False): """Construct a new Momentum optimizer. Args: @@ -44,6 +44,7 @@ class MomentumOptimizer(optimizer.Optimizer): super(MomentumOptimizer, self).__init__(use_locking, name) self._learning_rate = learning_rate self._momentum = momentum + self._use_nesterov = use_nesterov def _create_slots(self, var_list): for v in var_list: @@ -62,7 +63,8 @@ class MomentumOptimizer(optimizer.Optimizer): math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), grad, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), - use_locking=self._use_locking).op + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op def _apply_sparse(self, grad, var): mom = self.get_slot(var, "momentum") @@ -71,4 +73,5 @@ class MomentumOptimizer(optimizer.Optimizer): math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), grad.values, grad.indices, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), - use_locking=self._use_locking).op + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index 3807f9e8d34..a1cbf9bfb59 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -25,6 +25,13 @@ import tensorflow as tf class MomentumOptimizerTest(tf.test.TestCase): + def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): + var = var + accum * lr * momentum + accum = accum * momentum + g + var = var - lr * accum + var = var - accum * lr * momentum + return var, accum + def testBasic(self): for dtype in [tf.half, tf.float32, tf.float64]: with self.test_session(): @@ -80,6 +87,68 @@ class MomentumOptimizerTest(tf.test.TestCase): 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]), var1.eval()) + def testNesterovMomentum(self): + for dtype in [tf.float32, tf.float64]: + with self.test_session(): + var0 = tf.Variable([1.0, 2.0], dtype=dtype) + var1 = tf.Variable([3.0, 4.0], dtype=dtype) + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + cost = 5 * var0 * var0 + 3 * 
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
+                                            use_nesterov=True)
+        opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+        tf.initialize_all_variables().run()
+        for t in range(1, 5):
+          opt_op.run()
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
+              accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+              accum1_np, 3, 2.0, 0.9)
+          self.assertAllClose(var0_np, var0.eval())
+          self.assertAllClose(var1_np, var1.eval())
+
+  def testSparseNesterovMomentum(self):
+    for dtype in [tf.float32, tf.float64]:
+      with self.test_session():
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        grads = []
+        for t in range(1, 5):
+          grads.append(var0_np * 10)
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
+              accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+              accum1_np, 3, 2.0, 0.9)
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        var0 = tf.Variable(var0_np)
+        var1 = tf.Variable(var1_np)
+        loss = 5 * var0 * var0 + 3 * var1
+        mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
+                                            use_nesterov=True)
+        x_feed = tf.placeholder(dtype)
+        y_feed = tf.IndexedSlices(x_feed, tf.constant([0, 1]), tf.constant([2]))
+        grads_and_vars = [(y_feed, var0),
+                          (tf.constant([3.0, 3.0], dtype=dtype), var1)]
+        opt_update = mom_op.apply_gradients(grads_and_vars)
+        tf.initialize_all_variables().run()
+        for t in range(1, 5):
+          opt_update.run(feed_dict={x_feed: grads[t - 1]})
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
+              accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+              accum1_np, 3, 2.0, 0.9)
+          self.assertAllClose(var0_np, var0.eval())
+          self.assertAllClose(var1_np, var1.eval())
+
   def testTensorLearningRateAndMomentum(self):
     for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
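
A quick sanity check on the math (illustrative note; not part of the diff). With
use_nesterov=True the dense kernel above computes

    accum = accum * momentum + grad
    var  -= grad * lr + accum * momentum * lr

while the test's _update_nesterov_momentum_numpy helper reaches the same state
by stepping classical Nesterov momentum on theta = var + lr * momentum * accum
and then mapping theta back to var. A minimal NumPy sketch of that equivalence
(function names here are hypothetical, not from the patch):

    import numpy as np

    def kernel_step(var, accum, grad, lr, momentum):
        # Mirrors the dense ApplyMomentum functor above with use_nesterov=True.
        accum = accum * momentum + grad
        var = var - (grad * lr + accum * momentum * lr)
        return var, accum

    def helper_step(var, accum, g, lr, momentum):
        # Mirrors _update_nesterov_momentum_numpy from momentum_test.py above.
        var = var + accum * lr * momentum   # recover theta from the stored var
        accum = accum * momentum + g
        var = var - lr * accum              # classical momentum step on theta
        var = var - accum * lr * momentum   # map theta back to the stored var
        return var, accum

    v1, a1 = np.array([1.0, 2.0]), np.zeros(2)
    v2, a2 = np.array([1.0, 2.0]), np.zeros(2)
    for g in (np.array([0.1, 0.2]), np.array([0.3, 0.4])):
        v1, a1 = kernel_step(v1, a1, g, lr=2.0, momentum=0.9)
        v2, a2 = helper_step(v2, a2, g, lr=2.0, momentum=0.9)
    np.testing.assert_allclose(v1, v2)  # the two formulations agree exactly

From Python, the new path is opted into by passing use_nesterov=True to
tf.train.MomentumOptimizer(learning_rate, momentum), as the tests above do.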