Merge pull request #26993 from chie8842:change_rmsprop_doc

PiperOrigin-RevId: 242221932
Author: TensorFlower Gardener
Date:   2019-04-05 17:26:55 -07:00
Commit: 22b32624a2


@@ -13,8 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # pylint: disable=invalid-name
-"""Built-in optimizer classes.
-"""
+"""Built-in optimizer classes."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -112,28 +111,26 @@ class Optimizer(object):
     (otherwise the optimizer has no weights).

     Arguments:
-        weights: a list of Numpy arrays. The number
-            of arrays and their shape must match
-            number of the dimensions of the weights
-            of the optimizer (i.e. it should match the
-            output of `get_weights`).
+        weights: a list of Numpy arrays. The number of arrays and their shape
+          must match number of the dimensions of the weights of the optimizer
+          (i.e. it should match the output of `get_weights`).

     Raises:
         ValueError: in case of incompatible weight shapes.
     """
     params = self.weights
     if len(params) != len(weights):
-      raise ValueError(
-          'Length of the specified weight list (' + str(len(weights)) +
-          ') does not match the number of weights '
-          'of the optimizer (' + str(len(params)) + ')')
+      raise ValueError('Length of the specified weight list (' +
+                       str(len(weights)) +
+                       ') does not match the number of weights '
+                       'of the optimizer (' + str(len(params)) + ')')
     weight_value_tuples = []
     param_values = K.batch_get_value(params)
     for pv, p, w in zip(param_values, params, weights):
       if pv.shape != w.shape:
-        raise ValueError(
-            'Optimizer weight shape ' + str(pv.shape) + ' not compatible with '
-            'provided weight shape ' + str(w.shape))
+        raise ValueError('Optimizer weight shape ' + str(pv.shape) +
+                         ' not compatible with '
+                         'provided weight shape ' + str(w.shape))
       weight_value_tuples.append((p, w))
     K.batch_set_value(weight_value_tuples)
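For context on the `get_weights`/`set_weights` pair touched above, here is a minimal usage sketch, not part of the diff. It assumes a small `tf.keras` model and the 1.x-era `lr`/`momentum` constructor arguments used in this file; the model and data are hypothetical. The optimizer's weights only exist after at least one update has been applied, at which point they can be snapshotted and restored.

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    opt = tf.keras.optimizers.SGD(lr=0.01, momentum=0.9)
    model.compile(optimizer=opt, loss='mse')

    # One training step creates the optimizer's weights (iteration count, momenta).
    x = np.random.rand(8, 4).astype('float32')
    y = np.random.rand(8, 1).astype('float32')
    model.train_on_batch(x, y)

    state = opt.get_weights()  # list of Numpy arrays
    opt.set_weights(state)     # shapes must match, otherwise ValueError is raised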
@@ -166,8 +163,8 @@ class SGD(Optimizer):

   Arguments:
       lr: float >= 0. Learning rate.
-      momentum: float >= 0. Parameter that accelerates SGD
-          in the relevant direction and dampens oscillations.
+      momentum: float >= 0. Parameter that accelerates SGD in the relevant
+        direction and dampens oscillations.
       decay: float >= 0. Learning rate decay over each update.
       nesterov: boolean. Whether to apply Nesterov momentum.
   """
@@ -189,8 +186,9 @@ class SGD(Optimizer):
     lr = self.lr
     if self.initial_decay > 0:
       lr = lr * (  # pylint: disable=g-no-augmented-assignment
-          1. / (1. + self.decay * math_ops.cast(self.iterations,
-                                                K.dtype(self.decay))))
+          1. /
+          (1. +
+           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
     # momentum
     shapes = [K.int_shape(p) for p in params]
     moments = [K.zeros(shape) for shape in shapes]
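The block reformatted above implements simple time-based decay, lr_t = lr / (1 + decay * iterations). As an aside, a plain-Python sketch of the same schedule (the helper name is made up for illustration):

    def decayed_lr(lr, decay, iterations):
      # Time-based decay shared by SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax.
      return lr * (1. / (1. + decay * iterations))

    # With lr=0.01 and decay=1e-3, the effective rate after 1000 updates
    # is 0.01 / (1 + 1.0) = 0.005.
    print(decayed_lr(0.01, 1e-3, 1000))  # 0.005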
@@ -229,15 +227,11 @@ class RMSprop(Optimizer):
   at their default values
   (except the learning rate, which can be freely tuned).

-  This optimizer is usually a good choice for recurrent
-  neural networks.
-
   Arguments:
       lr: float >= 0. Learning rate.
       rho: float >= 0.
       epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
       decay: float >= 0. Learning rate decay over each update.
-
   """

   def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
@@ -261,8 +255,9 @@ class RMSprop(Optimizer):
     lr = self.lr
     if self.initial_decay > 0:
       lr = lr * (  # pylint: disable=g-no-augmented-assignment
-          1. / (1. + self.decay * math_ops.cast(self.iterations,
-                                                K.dtype(self.decay))))
+          1. /
+          (1. +
+           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

     for p, g, a in zip(params, grads, accumulators):
       # update accumulator
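The loop entered here keeps a moving average of squared gradients and divides each step by its square root. A NumPy sketch of one RMSprop update, with illustrative variable names rather than the ones in this file:

    import numpy as np

    def rmsprop_step(p, g, a, lr=0.001, rho=0.9, epsilon=1e-7):
      # a <- rho * a + (1 - rho) * g**2 ;  p <- p - lr * g / (sqrt(a) + epsilon)
      a = rho * a + (1. - rho) * np.square(g)
      p = p - lr * g / (np.sqrt(a) + epsilon)
      return p, a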
@@ -305,7 +300,8 @@ class Adagrad(Optimizer):
       decay: float >= 0. Learning rate decay over each update.

   # References
-      - [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+      - [Adaptive Subgradient Methods for Online Learning and Stochastic
+        Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
   """

   def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
@@ -329,8 +325,9 @@ class Adagrad(Optimizer):
     lr = self.lr
     if self.initial_decay > 0:
       lr = lr * (  # pylint: disable=g-no-augmented-assignment
-          1. / (1. + self.decay * math_ops.cast(self.iterations,
-                                                K.dtype(self.decay))))
+          1. /
+          (1. +
+           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

     for p, g, a in zip(params, grads, accumulators):
       new_a = a + math_ops.square(g)  # update accumulator
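Unlike RMSprop, Adagrad's accumulator is a running sum of squared gradients, so the effective step size only shrinks over time. A NumPy sketch of one update (illustrative names, not the variables in this file):

    import numpy as np

    def adagrad_step(p, g, a, lr=0.01, epsilon=1e-7):
      # a <- a + g**2 ;  p <- p - lr * g / (sqrt(a) + epsilon)
      a = a + np.square(g)
      p = p - lr * g / (np.sqrt(a) + epsilon)
      return p, a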
@@ -377,7 +374,8 @@ class Adadelta(Optimizer):
       decay: float >= 0. Initial learning rate decay.

   # References
-      - [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
+      - [Adadelta - an adaptive learning rate
+        method](http://arxiv.org/abs/1212.5701)
   """

   def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
@@ -403,8 +401,9 @@ class Adadelta(Optimizer):
     lr = self.lr
     if self.initial_decay > 0:
       lr = lr * (  # pylint: disable=g-no-augmented-assignment
-          1. / (1. + self.decay * math_ops.cast(self.iterations,
-                                                K.dtype(self.decay))))
+          1. /
+          (1. +
+           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

     for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
       # update accumulator
@@ -448,10 +447,8 @@ class Adam(Optimizer):
       beta_2: float, 0 < beta < 1. Generally close to 1.
       epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
       decay: float >= 0. Learning rate decay over each update.
-      amsgrad: boolean. Whether to apply the AMSGrad variant of this
-          algorithm from the paper "On the Convergence of Adam and
-          Beyond".
-
+      amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+        from the paper "On the Convergence of Adam and Beyond".
   """

   def __init__(self,
@@ -482,8 +479,9 @@ class Adam(Optimizer):
     lr = self.lr
     if self.initial_decay > 0:
       lr = lr * (  # pylint: disable=g-no-augmented-assignment
-          1. / (1. + self.decay * math_ops.cast(self.iterations,
-                                                K.dtype(self.decay))))
+          1. /
+          (1. +
+           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

     with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
       t = math_ops.cast(self.iterations, K.floatx())
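After the iteration counter is advanced here, Adam combines bias-corrected first- and second-moment estimates of the gradient. A NumPy sketch of one step without AMSGrad (illustrative names; t is the 1-based update count):

    import numpy as np

    def adam_step(p, g, m, v, t, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
      # Moment estimates and bias-corrected step size.
      m = beta_1 * m + (1. - beta_1) * g
      v = beta_2 * v + (1. - beta_2) * np.square(g)
      lr_t = lr * np.sqrt(1. - beta_2**t) / (1. - beta_1**t)
      p = p - lr_t * m / (np.sqrt(v) + epsilon)
      return p, m, v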
@@ -544,7 +542,6 @@ class Adamax(Optimizer):
       beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
       epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
       decay: float >= 0. Learning rate decay over each update.
-
   """

   def __init__(self,
def __init__(self, def __init__(self,
@ -573,8 +570,9 @@ class Adamax(Optimizer):
lr = self.lr lr = self.lr
if self.initial_decay > 0: if self.initial_decay > 0:
lr = lr * ( # pylint: disable=g-no-augmented-assignment lr = lr * ( # pylint: disable=g-no-augmented-assignment
1. / (1. + self.decay * math_ops.cast(self.iterations, 1. /
K.dtype(self.decay)))) (1. +
self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]): with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
t = math_ops.cast(self.iterations, K.floatx()) t = math_ops.cast(self.iterations, K.floatx())
@@ -630,7 +628,6 @@ class Nadam(Optimizer):
       lr: float >= 0. Learning rate.
       beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
       epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
-
   """

   def __init__(self,
@@ -683,8 +680,8 @@ class Nadam(Optimizer):
       m_t_prime = m_t / (1. - m_schedule_next)
       v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
       v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
-      m_t_bar = (
-          1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
+      m_t_bar = (1. -
+                 momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

       self.updates.append(state_ops.assign(m, m_t))
       self.updates.append(state_ops.assign(v, v_t))
@@ -712,8 +709,7 @@


 class TFOptimizer(Optimizer, trackable.Trackable):
-  """Wrapper class for native TensorFlow optimizers.
-  """
+  """Wrapper class for native TensorFlow optimizers."""

   def __init__(self, optimizer, iterations=None):  # pylint: disable=super-init-not-called
     self.optimizer = optimizer
@@ -792,10 +788,8 @@ def deserialize(config, custom_objects=None):

   Arguments:
       config: Optimizer configuration dictionary.
-      custom_objects: Optional dictionary mapping
-          names (strings) to custom objects
-          (classes and functions)
-          to be considered during deserialization.
+      custom_objects: Optional dictionary mapping names (strings) to custom
+        objects (classes and functions) to be considered during deserialization.

   Returns:
       A Keras Optimizer instance.
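A usage sketch of the serialize/deserialize round trip, assuming the module is reachable as `tf.keras.optimizers` and the 1.x-era `lr` argument shown in this file:

    import tensorflow as tf

    opt = tf.keras.optimizers.RMSprop(lr=0.001, rho=0.9)
    config = tf.keras.optimizers.serialize(opt)    # {'class_name': 'RMSprop', 'config': {...}}
    restored = tf.keras.optimizers.deserialize(config)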
@@ -828,10 +822,9 @@ def get(identifier):
   Arguments:
       identifier: Optimizer identifier, one of
           - String: name of an optimizer
-          - Dictionary: configuration dictionary.
-          - Keras Optimizer instance (it will be returned unchanged).
-          - TensorFlow Optimizer instance
-              (it will be wrapped as a Keras Optimizer).
+          - Dictionary: configuration dictionary. - Keras Optimizer instance (it
+            will be returned unchanged). - TensorFlow Optimizer instance (it
+            will be wrapped as a Keras Optimizer).

   Returns:
       A Keras Optimizer instance.
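A sketch of `get` with the identifier forms listed in the docstring above (same `tf.keras.optimizers` assumption; the config values are illustrative):

    import tensorflow as tf

    opt_a = tf.keras.optimizers.get('rmsprop')                 # by name, default arguments
    opt_b = tf.keras.optimizers.get({'class_name': 'SGD',
                                     'config': {'lr': 0.01}})  # from a config dict
    opt_c = tf.keras.optimizers.get(opt_a)                     # an instance is returned unchanged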