add support for nesterov momentum (#2798)

* add support for dense nesterov momentum
* add sparse version and docs
* clean work
* cleanups
* modify gpu functor
* clean work
parent fc9162975e
commit dadf364f85
@@ -59,8 +59,7 @@ struct ApplyAdadelta<CPUDevice, T> {
     accum.device(d) =
         accum * rho() + grad.square() * (static_cast<T>(1) - rho());
     const auto update =
-        (accum_update + epsilon()).sqrt() *
-        (accum + epsilon()).rsqrt() * grad;
+        (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
     accum_update.device(d) =
         accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
     var.device(d) -= update * lr();
@@ -176,9 +175,13 @@ struct ApplyMomentum<CPUDevice, T> {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum) {
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
     accum.device(d) = accum * momentum() + grad;
-    var.device(d) -= accum * lr();
+    if (use_nesterov) {
+      var.device(d) -= grad * lr() + accum * momentum() * lr();
+    } else {
+      var.device(d) -= accum * lr();
+    }
   }
 };

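For reference, the dense functor above amounts to the following update rule; a minimal NumPy sketch (function and variable names here are illustrative, not part of TensorFlow):

    import numpy as np

    def apply_momentum(var, accum, lr, grad, momentum, use_nesterov=False):
      # Accumulate first, as the functor does: accum = momentum * accum + grad.
      accum = momentum * accum + grad
      if use_nesterov:
        # Nesterov branch: var -= lr * grad + lr * momentum * accum,
        # i.e. the step also looks one momentum application ahead.
        var = var - (grad * lr + accum * momentum * lr)
      else:
        # Classical momentum: var -= lr * accum.
        var = var - accum * lr
      return var, accum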
@@ -1515,6 +1518,7 @@ class ApplyMomentumOp : public OpKernel {
  public:
   explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   }

   void Compute(OpKernelContext* ctx) override {
@@ -1554,12 +1558,13 @@ class ApplyMomentumOp : public OpKernel {
     const Device& device = ctx->template eigen_device<Device>();
     functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                         lr.scalar<T>(), grad.flat<T>(),
-                                        momentum.scalar<T>());
+                                        momentum.scalar<T>(), use_nesterov_);
     ctx->forward_ref_input_to_ref_output(0, 0);
   }

  private:
   bool use_exclusive_lock_;
+  bool use_nesterov_;
 };

 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -1584,7 +1589,7 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T>::Flat var, \
       typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
       typename TTypes<T>::ConstFlat grad, \
-      typename TTypes<T>::ConstScalar momentum); \
+      typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \
   extern template struct ApplyMomentum<GPUDevice, T>;
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
@@ -1605,6 +1610,7 @@ class SparseApplyMomentumOp : public OpKernel {
  public:
   explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
@@ -1672,7 +1678,12 @@ class SparseApplyMomentumOp : public OpKernel {
         auto g = grad_flat.template chip<0>(i);
         auto v = var_flat.template chip<0>(index);
         a = a * a.constant(momentum_scalar) + g;
-        v -= a.constant(lr_scalar) * a;
+        if (use_nesterov_) {
+          v -= g.constant(lr_scalar) * g +
+               a.constant(lr_scalar) * a.constant(momentum_scalar) * a;
+        } else {
+          v -= a.constant(lr_scalar) * a;
+        }
       }
     }

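The sparse kernel above applies the same rule row by row, but only for the rows named in the gradient's indices; roughly, in NumPy terms (again an illustrative sketch, not TensorFlow code):

    def sparse_apply_momentum(var, accum, lr, grad, indices, momentum,
                              use_nesterov=False):
      # grad holds one row per entry of indices; only those rows are touched.
      for i, index in enumerate(indices):
        g = grad[i]
        accum[index] = momentum * accum[index] + g
        if use_nesterov:
          var[index] = var[index] - (g * lr + accum[index] * momentum * lr)
        else:
          var[index] = var[index] - lr * accum[index]
      return var, accum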
@@ -1681,6 +1692,7 @@ class SparseApplyMomentumOp : public OpKernel {

  private:
   bool use_exclusive_lock_;
+  bool use_nesterov_;
 };

 #define REGISTER_KERNELS(T, Tindices) \
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
 #define TENSORFLOW_KERNELS_TRAINING_OPS_H_

-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
 namespace functor {
@@ -98,7 +98,7 @@ struct ApplyMomentum {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum);
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
 };

 template <typename Device, typename T>
@@ -17,8 +17,8 @@ limitations under the License.

 #define EIGEN_USE_GPU

-#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/training_ops.h"
+#include "tensorflow/core/framework/register_types.h"

 namespace tensorflow {

@@ -84,12 +84,18 @@ struct ApplyMomentum<GPUDevice, T> {
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad,
-                  typename TTypes<T>::ConstScalar momentum) {
+                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
-    var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+    if (use_nesterov) {
+      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
+                       accum * momentum.reshape(single).broadcast(bcast) *
+                           lr.reshape(single).broadcast(bcast);
+    } else {
+      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+    }
   }
 };

@@ -341,8 +341,10 @@ REGISTER_OP("ApplyMomentum")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
+    .Attr("use_nesterov: bool = false")
     .Doc(R"doc(
-Update '*var' according to the momentum scheme.
+Update '*var' according to the momentum scheme. Set use_nesterov = True if you
+want to use Nesterov momentum.

 accum = accum * momentum + grad
 var -= lr * accum
@@ -356,6 +358,9 @@ out: Same as "var".
 use_locking: If `True`, updating of the var and accum tensors will be protected
   by a lock; otherwise the behavior is undefined, but may exhibit less
   contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+  var - lr * momentum * accum, so in the end, the var you get is actually
+  var - lr * momentum * accum.
 )doc");

 REGISTER_OP("SparseApplyMomentum")
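The use_nesterov wording reflects that the variable the op stores and returns is the Nesterov look-ahead point. The fused kernel update and the four-step reference sequence used in the tests further down compute the same step; a small NumPy check (ours, illustrative only) makes the equivalence concrete:

    import numpy as np

    def kernel_step(var, accum, g, lr, m):
      # What ApplyMomentum with use_nesterov = True does in one fused update.
      accum = m * accum + g
      return var - (g * lr + accum * m * lr), accum

    def reference_step(var, accum, g, lr, m):
      # The four-step form used by _update_nesterov_momentum_numpy in the tests.
      var = var + accum * lr * m
      accum = accum * m + g
      var = var - lr * accum
      var = var - accum * lr * m
      return var, accum

    v = np.array([1.0, 2.0])
    a = np.array([0.5, -0.5])
    g = np.array([0.1, 0.3])
    assert np.allclose(kernel_step(v, a, g, 2.0, 0.9),
                       reference_step(v, a, g, 2.0, 0.9))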
@@ -369,8 +374,10 @@ REGISTER_OP("SparseApplyMomentum")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .Attr("use_nesterov: bool = false")
     .Doc(R"doc(
 Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+Set use_nesterov = True if you want to use Nesterov momentum.

 That is for rows we have grad for, we update var and accum as follows:

@@ -387,6 +394,9 @@ out: Same as "var".
 use_locking: If `True`, updating of the var and accum tensors will be protected
   by a lock; otherwise the behavior is undefined, but may exhibit less
   contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+  var - lr * momentum * accum, so in the end, the var you get is actually
+  var - lr * momentum * accum.
 )doc");

 REGISTER_OP("ApplyAdam")
@@ -31,7 +31,7 @@ class MomentumOptimizer(optimizer.Optimizer):
   """

   def __init__(self, learning_rate, momentum,
-               use_locking=False, name="Momentum"):
+               use_locking=False, name="Momentum", use_nesterov=False):
     """Construct a new Momentum optimizer.

     Args:
@@ -44,6 +44,7 @@ class MomentumOptimizer(optimizer.Optimizer):
     super(MomentumOptimizer, self).__init__(use_locking, name)
     self._learning_rate = learning_rate
     self._momentum = momentum
+    self._use_nesterov = use_nesterov

   def _create_slots(self, var_list):
     for v in var_list:
@@ -62,7 +63,8 @@ class MomentumOptimizer(optimizer.Optimizer):
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
         grad,
         math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
-        use_locking=self._use_locking).op
+        use_locking=self._use_locking,
+        use_nesterov=self._use_nesterov).op

   def _apply_sparse(self, grad, var):
     mom = self.get_slot(var, "momentum")
@@ -71,4 +73,5 @@ class MomentumOptimizer(optimizer.Optimizer):
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
         grad.values, grad.indices,
         math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
-        use_locking=self._use_locking).op
+        use_locking=self._use_locking,
+        use_nesterov=self._use_nesterov).op
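With the kernels, op registration, and Python wrapper above all threaded through, the flag is usable end to end; a minimal usage sketch in the TF 0.x session style the tests below also use (names and values are illustrative):

    import tensorflow as tf

    var = tf.Variable([1.0, 2.0])
    loss = tf.reduce_sum(tf.square(var))
    opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9,
                                     use_nesterov=True)
    train_op = opt.minimize(loss)

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      for _ in range(10):
        sess.run(train_op)
      print(sess.run(var))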
@@ -25,6 +25,13 @@ import tensorflow as tf

 class MomentumOptimizerTest(tf.test.TestCase):

+  def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
+    var = var + accum * lr * momentum
+    accum = accum * momentum + g
+    var = var - lr * accum
+    var = var - accum * lr * momentum
+    return var, accum
+
   def testBasic(self):
     for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
@@ -80,6 +87,68 @@ class MomentumOptimizerTest(tf.test.TestCase):
                                     3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
                           var1.eval())

+  def testNesterovMomentum(self):
+    for dtype in [tf.float32, tf.float64]:
+      with self.test_session():
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        cost = 5 * var0 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
+                                            use_nesterov=True)
+        opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+        tf.initialize_all_variables().run()
+        for t in range(1, 5):
+          opt_op.run()
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
+          self.assertAllClose(var0_np, var0.eval())
+          self.assertAllClose(var1_np, var1.eval())
+
+  def testSparseNesterovMomentum(self):
+    for dtype in [tf.float32, tf.float64]:
+      with self.test_session():
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        grads = []
+        for t in range(1, 5):
+          grads.append(var0_np * 10)
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        var0 = tf.Variable(var0_np)
+        var1 = tf.Variable(var1_np)
+        loss = 5 * var0 * var0 + 3 * var1
+        mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9,
+                                            use_nesterov=True)
+        x_feed = tf.placeholder(dtype)
+        y_feed = tf.IndexedSlices(x_feed, tf.constant([0, 1]), tf.constant([2]))
+        grads_and_vars = [(y_feed, var0),
+                          (tf.constant([3.0, 3.0], dtype=dtype), var1)]
+        opt_update = mom_op.apply_gradients(grads_and_vars)
+        tf.initialize_all_variables().run()
+        for t in range(1, 5):
+          opt_update.run(feed_dict={x_feed: grads[t - 1]})
+          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
+          self.assertAllClose(var0_np, var0.eval())
+          self.assertAllClose(var1_np, var1.eval())
+
   def testTensorLearningRateAndMomentum(self):
     for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():