Change momentum optimizer to allow callable learning_rate and momentum
parameters. This can be useful for implementing learning rate decay.

PiperOrigin-RevId: 173975321
commit 187453d61d
parent 542b323e5a
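For context, a hedged sketch (not part of this commit) of how the new callable parameters might be used for a step-based learning rate decay; the schedule, the `global_step` counter, and the constants below are illustrative assumptions.

# Hedged sketch: a decay schedule supplied as a zero-argument callable.
import tensorflow as tf

global_step = [0]  # assumed to be incremented elsewhere by the training loop

def decayed_learning_rate():
  # Halve the base rate every 1000 steps (hypothetical schedule).
  return 0.1 * 0.5 ** (global_step[0] // 1000)

opt = tf.train.MomentumOptimizer(learning_rate=decayed_learning_rate,
                                 momentum=lambda: 0.9)
# With eager execution enabled, the callables are re-evaluated on each
# optimizer invocation; in graph mode they are evaluated once when the
# update op is built.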
@@ -28,7 +28,7 @@ class MomentumOptimizer(optimizer.Optimizer):
   """Optimizer that implements the Momentum algorithm.

   Computes (if `use_nesterov = False`):

   ```
   accumulation = momentum * accumulation + gradient
   variable -= learning_rate * accumulation
@@ -58,6 +58,12 @@ class MomentumOptimizer(optimizer.Optimizer):
         variable(s) passed to the optimizer. Using Nesterov Momentum makes the
         variable(s) track the values called `theta_t + mu*v_t` in the paper.

+    @compatibility(eager)
+    When eager execution is enabled, learning_rate and momentum can each be a
+    callable that takes no arguments and returns the actual value to use. This
+    can be useful for changing these values across different invocations of
+    optimizer functions.
+    @end_compatibility
     """
     super(MomentumOptimizer, self).__init__(use_locking, name)
     self._learning_rate = learning_rate
@@ -69,10 +75,15 @@ class MomentumOptimizer(optimizer.Optimizer):
       self._zeros_slot(v, "momentum", self._name)

   def _prepare(self):
-    self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+    learning_rate = self._learning_rate
+    if callable(learning_rate):
+      learning_rate = learning_rate()
+    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
                                                        name="learning_rate")
-    self._momentum_tensor = ops.convert_to_tensor(self._momentum,
-                                                  name="momentum")
+    momentum = self._momentum
+    if callable(momentum):
+      momentum = momentum()
+    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")

   def _apply_dense(self, grad, var):
     mom = self.get_slot(var, "momentum")
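The change to _prepare above resolves a possibly-callable hyperparameter just before converting it to a tensor. A minimal standalone sketch of that resolution idiom follows; the helper name `resolve` is an assumption, not part of the change.

def resolve(value):
  # Zero-argument callables are invoked to get the current value;
  # plain floats or tensors pass through unchanged.
  return value() if callable(value) else value

assert resolve(0.9) == 0.9
assert resolve(lambda: 0.9) == 0.9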
@@ -44,7 +44,7 @@ class MomentumOptimizerTest(test.TestCase):
     var = var - accum * lr * momentum
     return var, accum

-  def doTestBasic(self, use_resource=False):
+  def doTestBasic(self, use_resource=False, use_callable_params=False):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
       if use_resource:
         var0 = resource_variable_ops.ResourceVariable(
@@ -56,8 +56,13 @@ class MomentumOptimizerTest(test.TestCase):
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
       grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
       grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+      learning_rate = lambda: 2.0
+      momentum = lambda: 0.9
+      if not use_callable_params:
+        learning_rate = learning_rate()
+        momentum = momentum()
       mom_opt = momentum_lib.MomentumOptimizer(
-          learning_rate=2.0, momentum=0.9)
+          learning_rate=learning_rate, momentum=momentum)
       mom_update = mom_opt.apply_gradients(
           zip([grads0, grads1], [var0, var1]))

@@ -125,6 +130,10 @@ class MomentumOptimizerTest(test.TestCase):
   def testResourceBasic(self):
     self.doTestBasic(use_resource=True)

+  def testBasicCallableParams(self):
+    with context.eager_mode():
+      self.doTestBasic(use_resource=True, use_callable_params=True)
+
   def testNesterovMomentum(self):
     for dtype in [dtypes.float32, dtypes.float64]:
       with self.test_session():
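The test above exercises callables under eager execution; the constructor also accepts them in graph mode. A hedged graph-mode sketch follows (TF 1.x public API assumed); note that _prepare runs while the update op is being built, so in graph mode the callable is evaluated a single time per constructed op rather than once per training step.

import tensorflow as tf

# Callable learning_rate, plain float momentum; mixing the two is allowed.
opt = tf.train.MomentumOptimizer(learning_rate=lambda: 0.01, momentum=0.9)

var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])
train_op = opt.apply_gradients([(grad, var)])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)  # applies the momentum update with learning_rate == 0.01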