Change momentum optimizer to allow callable learning_rate and momentum
parameters. This can be useful for implementing learning rate decay.

PiperOrigin-RevId: 173975321
This commit is contained in:
parent 542b323e5a
commit 187453d61d
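The change described above lets `learning_rate` and `momentum` be passed as zero-argument callables, which is what makes schedules such as learning rate decay possible. A rough usage sketch, not part of this commit, assuming a TF 2.x environment where eager execution is on by default and the 1.x optimizer is available as `tf.compat.v1.train.MomentumOptimizer`; the decay schedule, `step` counter, and variable names are made up for illustration:

```
import tensorflow as tf  # assumes TF 2.x: eager by default, 1.x optimizer under tf.compat.v1

step = tf.Variable(0, trainable=False, dtype=tf.int64)

def decayed_learning_rate():
  # Hypothetical schedule: halve a base rate of 0.1 every 100 steps.
  return 0.1 * 0.5 ** tf.cast(step // 100, tf.float32)

# learning_rate and momentum may now be zero-argument callables.
opt = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=decayed_learning_rate, momentum=lambda: 0.9)

w = tf.Variable([1.0, 2.0])
for _ in range(3):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(w))
  grads = tape.gradient(loss, [w])
  opt.apply_gradients(zip(grads, [w]))  # re-reads the callables on each call
  step.assign_add(1)
```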
@@ -28,7 +28,7 @@ class MomentumOptimizer(optimizer.Optimizer):
   """Optimizer that implements the Momentum algorithm.
 
   Computes (if `use_nesterov = False`):
 
   ```
   accumulation = momentum * accumulation + gradient
   variable -= learning_rate * accumulation
@@ -58,6 +58,12 @@ class MomentumOptimizer(optimizer.Optimizer):
         variable(s) passed to the optimizer. Using Nesterov Momentum makes the
         variable(s) track the values called `theta_t + mu*v_t` in the paper.
 
+    @compatibility(eager)
+    When eager execution is enabled, learning_rate and momentum can each be a
+    callable that takes no arguments and returns the actual value to use. This
+    can be useful for changing these values across different invocations of
+    optimizer functions.
+    @end_compatibility
     """
     super(MomentumOptimizer, self).__init__(use_locking, name)
     self._learning_rate = learning_rate
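The `@compatibility(eager)` note added above is the user-facing contract: under eager execution the callables are re-evaluated on each optimizer invocation, so their return values can change between calls. A minimal sketch of that behaviour, again assuming `tf.compat.v1.train.MomentumOptimizer` under TF 2.x eager execution, with `lr_holder` standing in for whatever schedule state a user might keep:

```
import tensorflow as tf

lr_holder = {"value": 2.0}  # mutable container standing in for any schedule
opt = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=lambda: lr_holder["value"],  # zero-argument callable
    momentum=0.9)                              # plain floats still work

var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])

opt.apply_gradients([(grad, var)])  # this update uses learning_rate 2.0
lr_holder["value"] = 0.2            # change the value between invocations
opt.apply_gradients([(grad, var)])  # this update uses learning_rate 0.2
```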
@@ -69,10 +75,15 @@ class MomentumOptimizer(optimizer.Optimizer):
       self._zeros_slot(v, "momentum", self._name)
 
   def _prepare(self):
-    self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+    learning_rate = self._learning_rate
+    if callable(learning_rate):
+      learning_rate = learning_rate()
+    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
                                                         name="learning_rate")
-    self._momentum_tensor = ops.convert_to_tensor(self._momentum,
-                                                  name="momentum")
+    momentum = self._momentum
+    if callable(momentum):
+      momentum = momentum()
+    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
 
   def _apply_dense(self, grad, var):
     mom = self.get_slot(var, "momentum")
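The rewritten `_prepare` above reduces to a small resolve-if-callable step ahead of the usual tensor conversion. An illustrative standalone version of that pattern; `resolve_hyperparameter` is a made-up helper name, not an API from the diff:

```
import tensorflow as tf

def resolve_hyperparameter(value, name):
  # Accept either a plain value/tensor or a zero-argument callable.
  if callable(value):
    value = value()
  return tf.convert_to_tensor(value, name=name)

lr_tensor = resolve_hyperparameter(lambda: 2.0, name="learning_rate")
momentum_tensor = resolve_hyperparameter(0.9, name="momentum")
```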
@@ -44,7 +44,7 @@ class MomentumOptimizerTest(test.TestCase):
     var = var - accum * lr * momentum
     return var, accum
 
-  def doTestBasic(self, use_resource=False):
+  def doTestBasic(self, use_resource=False, use_callable_params=False):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
       if use_resource:
         var0 = resource_variable_ops.ResourceVariable(
@@ -56,8 +56,13 @@ class MomentumOptimizerTest(test.TestCase):
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
       grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
       grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+      learning_rate = lambda: 2.0
+      momentum = lambda: 0.9
+      if not use_callable_params:
+        learning_rate = learning_rate()
+        momentum = momentum()
       mom_opt = momentum_lib.MomentumOptimizer(
-          learning_rate=2.0, momentum=0.9)
+          learning_rate=learning_rate, momentum=momentum)
       mom_update = mom_opt.apply_gradients(
           zip([grads0, grads1], [var0, var1]))
 
@@ -125,6 +130,10 @@ class MomentumOptimizerTest(test.TestCase):
   def testResourceBasic(self):
     self.doTestBasic(use_resource=True)
 
+  def testBasicCallableParams(self):
+    with context.eager_mode():
+      self.doTestBasic(use_resource=True, use_callable_params=True)
+
   def testNesterovMomentum(self):
     for dtype in [dtypes.float32, dtypes.float64]:
       with self.test_session():