add support for nesterov momentum (#2798)
* add support for dense nesterov momentum * add sparse version and docs * clean work * cleanups * modify gpu functor * clean work
This commit is contained in:
parent
fc9162975e
commit
dadf364f85
@ -58,9 +58,8 @@ struct ApplyAdadelta<CPUDevice, T> {
|
||||
typename TTypes<T>::ConstFlat grad) {
|
||||
accum.device(d) =
|
||||
accum * rho() + grad.square() * (static_cast<T>(1) - rho());
|
||||
const auto update =
|
||||
(accum_update + epsilon()).sqrt() *
|
||||
(accum + epsilon()).rsqrt() * grad;
|
||||
const auto update =
|
||||
(accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
|
||||
accum_update.device(d) =
|
||||
accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
|
||||
var.device(d) -= update * lr();
|
||||
@ -176,9 +175,13 @@ struct ApplyMomentum<CPUDevice, T> {
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstFlat grad,
|
||||
typename TTypes<T>::ConstScalar momentum) {
|
||||
typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
|
||||
accum.device(d) = accum * momentum() + grad;
|
||||
var.device(d) -= accum * lr();
|
||||
if (use_nesterov) {
|
||||
var.device(d) -= grad * lr() + accum * momentum() * lr();
|
||||
} else {
|
||||
var.device(d) -= accum * lr();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -1515,6 +1518,7 @@ class ApplyMomentumOp : public OpKernel {
|
||||
public:
|
||||
explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
@ -1554,12 +1558,13 @@ class ApplyMomentumOp : public OpKernel {
|
||||
const Device& device = ctx->template eigen_device<Device>();
|
||||
functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
|
||||
lr.scalar<T>(), grad.flat<T>(),
|
||||
momentum.scalar<T>());
|
||||
momentum.scalar<T>(), use_nesterov_);
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
bool use_nesterov_;
|
||||
};
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -1584,7 +1589,7 @@ namespace functor {
|
||||
const GPUDevice& d, typename TTypes<T>::Flat var, \
|
||||
typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
|
||||
typename TTypes<T>::ConstFlat grad, \
|
||||
typename TTypes<T>::ConstScalar momentum); \
|
||||
typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \
|
||||
extern template struct ApplyMomentum<GPUDevice, T>;
|
||||
DECLARE_GPU_SPEC(Eigen::half);
|
||||
DECLARE_GPU_SPEC(float);
|
||||
@ -1605,6 +1610,7 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
@ -1672,7 +1678,12 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
auto g = grad_flat.template chip<0>(i);
|
||||
auto v = var_flat.template chip<0>(index);
|
||||
a = a * a.constant(momentum_scalar) + g;
|
||||
v -= a.constant(lr_scalar) * a;
|
||||
if (use_nesterov_) {
|
||||
v -= g.constant(lr_scalar) * g +
|
||||
a.constant(lr_scalar) * a.constant(momentum_scalar) * a;
|
||||
} else {
|
||||
v -= a.constant(lr_scalar) * a;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1681,6 +1692,7 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
bool use_nesterov_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
|
||||
#define TENSORFLOW_KERNELS_TRAINING_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
@ -98,7 +98,7 @@ struct ApplyMomentum {
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstFlat grad,
|
||||
typename TTypes<T>::ConstScalar momentum);
|
||||
typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/training_ops.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -84,12 +84,18 @@ struct ApplyMomentum<GPUDevice, T> {
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstFlat grad,
|
||||
typename TTypes<T>::ConstScalar momentum) {
|
||||
typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
|
||||
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
|
||||
bcast[0] = grad.dimension(0);
|
||||
Eigen::Sizes<1> single;
|
||||
accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
|
||||
var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
|
||||
if (use_nesterov) {
|
||||
var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
|
||||
accum * momentum.reshape(single).broadcast(bcast) *
|
||||
lr.reshape(single).broadcast(bcast);
|
||||
} else {
|
||||
var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -341,8 +341,10 @@ REGISTER_OP("ApplyMomentum")
|
||||
.Output("out: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("use_nesterov: bool = false")
|
||||
.Doc(R"doc(
|
||||
Update '*var' according to the momentum scheme.
|
||||
Update '*var' according to the momentum scheme. Set use_nesterov = True if you
|
||||
want to use Nesterov momentum.
|
||||
|
||||
accum = accum * momentum + grad
|
||||
var -= lr * accum
|
||||
@ -356,6 +358,9 @@ out: Same as "var".
|
||||
use_locking: If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
use_nesterov: If `True`, the tensor passed to compute grad will be
|
||||
var - lr * momentum * accum, so in the end, the var you get is actually
|
||||
var - lr * momentum * accum.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SparseApplyMomentum")
|
||||
@ -369,8 +374,10 @@ REGISTER_OP("SparseApplyMomentum")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("use_nesterov: bool = false")
|
||||
.Doc(R"doc(
|
||||
Update relevant entries in '*var' and '*accum' according to the momentum scheme.
|
||||
Set use_nesterov = True if you want to use Nesterov momentum.
|
||||
|
||||
That is for rows we have grad for, we update var and accum as follows:
|
||||
|
||||
@ -387,6 +394,9 @@ out: Same as "var".
|
||||
use_locking: If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
use_nesterov: If `True`, the tensor passed to compute grad will be
|
||||
var - lr * momentum * accum, so in the end, the var you get is actually
|
||||
var - lr * momentum * accum.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ApplyAdam")
|
||||
|
@ -31,7 +31,7 @@ class MomentumOptimizer(optimizer.Optimizer):
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate, momentum,
|
||||
use_locking=False, name="Momentum"):
|
||||
use_locking=False, name="Momentum", use_nesterov=False):
|
||||
"""Construct a new Momentum optimizer.
|
||||
|
||||
Args:
|
||||
@ -44,6 +44,7 @@ class MomentumOptimizer(optimizer.Optimizer):
|
||||
super(MomentumOptimizer, self).__init__(use_locking, name)
|
||||
self._learning_rate = learning_rate
|
||||
self._momentum = momentum
|
||||
self._use_nesterov = use_nesterov
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
for v in var_list:
|
||||
@ -62,7 +63,8 @@ class MomentumOptimizer(optimizer.Optimizer):
|
||||
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
|
||||
grad,
|
||||
math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
|
||||
use_locking=self._use_locking).op
|
||||
use_locking=self._use_locking,
|
||||
use_nesterov=self._use_nesterov).op
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
mom = self.get_slot(var, "momentum")
|
||||
@ -71,4 +73,5 @@ class MomentumOptimizer(optimizer.Optimizer):
|
||||
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
|
||||
grad.values, grad.indices,
|
||||
math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
|
||||
use_locking=self._use_locking).op
|
||||
use_locking=self._use_locking,
|
||||
use_nesterov=self._use_nesterov).op
|
||||
|
@ -25,6 +25,13 @@ import tensorflow as tf
|
||||
|
||||
class MomentumOptimizerTest(tf.test.TestCase):
|
||||
|
||||
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
|
||||
var = var + accum * lr * momentum
|
||||
accum = accum * momentum + g
|
||||
var = var - lr * accum
|
||||
var = var - accum * lr * momentum
|
||||
return var, accum
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in [tf.half, tf.float32, tf.float64]:
|
||||
with self.test_session():
|
||||
@ -80,6 +87,68 @@ class MomentumOptimizerTest(tf.test.TestCase):
|
||||
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
|
||||
var1.eval())
|
||||
|
||||
def testNesterovMomentum(self):
|
||||
for dtype in [tf.float32, tf.float64]:
|
||||
with self.test_session():
|
||||
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
cost = 5 * var0 * var0 + 3 * var1
|
||||
global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
|
||||
mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
|
||||
use_nesterov=True)
|
||||
opt_op = mom_op.minimize(cost, global_step, [var0, var1])
|
||||
tf.initialize_all_variables().run()
|
||||
for t in range(1, 5):
|
||||
opt_op.run()
|
||||
var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
|
||||
accum0_np, var0_np * 10, 2.0, 0.9)
|
||||
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
|
||||
accum1_np, 3, 2.0, 0.9)
|
||||
self.assertAllClose(var0_np, var0.eval())
|
||||
self.assertAllClose(var1_np, var1.eval())
|
||||
|
||||
def testSparseNesterovMomentum(self):
|
||||
for dtype in [tf.float32, tf.float64]:
|
||||
with self.test_session():
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
grads = []
|
||||
for t in range(1, 5):
|
||||
grads.append(var0_np * 10)
|
||||
var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
|
||||
accum0_np, var0_np * 10, 2.0, 0.9)
|
||||
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
|
||||
accum1_np, 3, 2.0, 0.9)
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
|
||||
var0 = tf.Variable(var0_np)
|
||||
var1 = tf.Variable(var1_np)
|
||||
loss = 5 * var0 * var0 + 3 * var1
|
||||
mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
|
||||
use_nesterov=True)
|
||||
x_feed = tf.placeholder(dtype)
|
||||
y_feed = tf.IndexedSlices(x_feed,tf.constant([0, 1]),tf.constant([2]))
|
||||
grads_and_vars = [(y_feed, var0),
|
||||
(tf.constant([3.0,3.0],dtype=dtype), var1)]
|
||||
opt_update = mom_op.apply_gradients(grads_and_vars)
|
||||
tf.initialize_all_variables().run()
|
||||
for t in range(1, 5):
|
||||
opt_update.run(feed_dict = {x_feed:grads[t - 1]})
|
||||
var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np,
|
||||
accum0_np, var0_np * 10, 2.0, 0.9)
|
||||
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
|
||||
accum1_np, 3, 2.0, 0.9)
|
||||
self.assertAllClose(var0_np, var0.eval())
|
||||
self.assertAllClose(var1_np, var1.eval())
|
||||
|
||||
def testTensorLearningRateAndMomentum(self):
|
||||
for dtype in [tf.half, tf.float32, tf.float64]:
|
||||
with self.test_session():
|
||||
|
Loading…
Reference in New Issue
Block a user