Change momentum optimizer to allow callable learning_rate and momentum
parameters. This can be useful for implementing learning rate decay.
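A minimal sketch (not part of this commit) of the decay use case, assuming eager execution is enabled: the learning rate is supplied as a zero-argument callable that is re-evaluated each time the optimizer prepares an update. The step counter and decay schedule below are illustrative assumptions, not TensorFlow APIs.

```python
import tensorflow as tf

global_step = 0  # advanced by the training loop (illustrative only)

def decayed_learning_rate():
  # Called with no arguments on each optimizer invocation, so the value can
  # change between steps without rebuilding the optimizer.
  return 0.1 * (0.96 ** (global_step // 1000))

optimizer = tf.train.MomentumOptimizer(
    learning_rate=decayed_learning_rate,  # a callable instead of a float
    momentum=lambda: 0.9)                 # momentum may also be a callable
```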

PiperOrigin-RevId: 173975321
Authored by A. Unique TensorFlower on 2017-10-30 17:27:08 -07:00; committed by TensorFlower Gardener
parent 542b323e5a
commit 187453d61d
2 changed files with 26 additions and 6 deletions

tensorflow/python/training/momentum.py

@@ -28,7 +28,7 @@ class MomentumOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
```
accumulation = momentum * accumulation + gradient
variable -= learning_rate * accumulation
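For orientation, a plain-Python trace of one such update, using the same constants the test file below uses (learning_rate=2.0, momentum=0.9, gradient=0.1, initial variable 1.0):

```python
accumulation = 0.0
variable = 1.0
gradient, momentum, learning_rate = 0.1, 0.9, 2.0

accumulation = momentum * accumulation + gradient  # 0.9*0.0 + 0.1 = 0.1
variable -= learning_rate * accumulation           # 1.0 - 2.0*0.1 = 0.8
```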
@@ -58,6 +58,12 @@ class MomentumOptimizer(optimizer.Optimizer):
         variable(s) passed to the optimizer. Using Nesterov Momentum makes the
         variable(s) track the values called `theta_t + mu*v_t` in the paper.
+    @compatibility(eager)
+    When eager execution is enabled, learning_rate and momentum can each be a
+    callable that takes no arguments and returns the actual value to use. This
+    can be useful for changing these values across different invocations of
+    optimizer functions.
+    @end_compatibility
     """
     super(MomentumOptimizer, self).__init__(use_locking, name)
     self._learning_rate = learning_rate
@@ -69,10 +75,15 @@ class MomentumOptimizer(optimizer.Optimizer):
       self._zeros_slot(v, "momentum", self._name)
 
   def _prepare(self):
-    self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
-                                                       name="learning_rate")
-    self._momentum_tensor = ops.convert_to_tensor(self._momentum,
-                                                  name="momentum")
+    learning_rate = self._learning_rate
+    if callable(learning_rate):
+      learning_rate = learning_rate()
+    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
+                                                       name="learning_rate")
+    momentum = self._momentum
+    if callable(momentum):
+      momentum = momentum()
+    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
 
   def _apply_dense(self, grad, var):
     mom = self.get_slot(var, "momentum")

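The `_prepare` hunk above reduces to a "resolve the callable, then convert to a tensor" step. A self-contained sketch of that pattern follows; the helper name `_resolve` is hypothetical and not part of the TensorFlow API.

```python
from tensorflow.python.framework import ops

def _resolve(value_or_callable, name):
  # Mirrors the new _prepare logic: call a zero-argument callable to get the
  # current value, then convert whatever results into a named tensor.
  value = value_or_callable() if callable(value_or_callable) else value_or_callable
  return ops.convert_to_tensor(value, name=name)

lr_t = _resolve(lambda: 2.0, name="learning_rate")  # callable hyperparameter
mom_t = _resolve(0.9, name="momentum")              # a plain float still works
```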
tensorflow/python/training/momentum_test.py

@@ -44,7 +44,7 @@ class MomentumOptimizerTest(test.TestCase):
     var = var - accum * lr * momentum
     return var, accum
 
-  def doTestBasic(self, use_resource=False):
+  def doTestBasic(self, use_resource=False, use_callable_params=False):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
       if use_resource:
         var0 = resource_variable_ops.ResourceVariable(
@@ -56,8 +56,13 @@ class MomentumOptimizerTest(test.TestCase):
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
       grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
       grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+      learning_rate = lambda: 2.0
+      momentum = lambda: 0.9
+      if not use_callable_params:
+        learning_rate = learning_rate()
+        momentum = momentum()
       mom_opt = momentum_lib.MomentumOptimizer(
-          learning_rate=2.0, momentum=0.9)
+          learning_rate=learning_rate, momentum=momentum)
       mom_update = mom_opt.apply_gradients(
           zip([grads0, grads1], [var0, var1]))
@@ -125,6 +130,10 @@ class MomentumOptimizerTest(test.TestCase):
   def testResourceBasic(self):
     self.doTestBasic(use_resource=True)
 
+  def testBasicCallableParams(self):
+    with context.eager_mode():
+      self.doTestBasic(use_resource=True, use_callable_params=True)
+
   def testNesterovMomentum(self):
     for dtype in [dtypes.float32, dtypes.float64]:
       with self.test_session():