Move all the docstrings in init to the class docstrings except Args.

PiperOrigin-RevId: 301211245
Change-Id: I1007dbe7536bc8439c95fb6789a73c36dfc4224b
Yash Katariya 2020-03-16 12:13:28 -07:00 committed by TensorFlower Gardener
parent 7680958e81
commit a3d80814af
7 changed files with 513 additions and 504 deletions

View File

@@ -52,10 +52,23 @@ class Adadelta(optimizer_v2.OptimizerV2):
$$E[\Delta x^2]_t := \rho * E[\Delta x^2]_{t-1} + (1 - \rho) * \Delta x_t^2$$
$$x_t := x_{t-1} + \Delta x_{t}$$
Adadelta is a more robust extension of Adagrad that adapts learning rates
based on a moving window of gradient updates, instead of accumulating all
past gradients. This way, Adadelta continues learning even when many updates
have been done. Compared to Adagrad, in the original version of Adadelta you
don't have to set an initial learning rate. In this version, initial
learning rate can be set, as in most other Keras optimizers.
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
each be a callable that takes no arguments and returns the actual value to
use. This can be useful for changing these values across different
invocations of optimizer functions.
@end_compatibility
References
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
_HAS_ALL_REDUCE_SUM_GRAD = True
@@ -68,13 +81,6 @@ class Adadelta(optimizer_v2.OptimizerV2):
**kwargs):
"""Construct a new Adadelta optimizer.
Adadelta is a more robust extension of Adagrad that adapts learning rates
based on a moving window of gradient updates, instead of accumulating all
past gradients. This way, Adadelta continues learning even when many updates
have been done. Compared to Adagrad, in the original version of Adadelta you
don't have to set an initial learning rate. In this version, initial
learning rate can be set, as in most other Keras optimizers.
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
`tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
@@ -89,13 +95,6 @@ class Adadelta(optimizer_v2.OptimizerV2):
gradients by value, `decay` is included for backward compatibility to
allow time inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
each be a callable that takes no arguments and returns the actual value to
use. This can be useful for changing these values across different
invocations of optimizer functions.
@end_compatibility
"""
super(Adadelta, self).__init__(name, **kwargs)
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
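As a concrete illustration of the relocated Adadelta docstring, a minimal usage sketch (assuming the standard `tf.keras` API at this revision; hyperparameter values are illustrative):

```python
import tensorflow as tf

# Default-style Adadelta; learning_rate can be set as in other Keras optimizers.
opt = tf.keras.optimizers.Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-7)
var = tf.Variable(10.0)
loss = lambda: (var ** 2) / 2.0     # d(loss)/d(var) == var
opt.minimize(loss, var_list=[var])  # one Adadelta update step
```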

View File

@@ -47,6 +47,13 @@ class Adagrad(optimizer_v2.OptimizerV2):
$$accum_{g_t} := accum_{g_{t-1}} + g^2$$
$$\theta_t := \theta_{t-1} - lr * g / (\sqrt{accum_{g_t}} + \epsilon)$$
@compatibility(eager)
When eager execution is enabled, `learning_rate` can be a callable that
takes no arguments and returns the actual value to use. This can be useful
for changing these values across different invocations of optimizer
functions.
@end_compatibility
References:
* [Paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
@@ -80,13 +87,6 @@ class Adagrad(optimizer_v2.OptimizerV2):
Raises:
ValueError: If the `initial_accumulator_value` or `epsilon` is invalid.
@compatibility(eager)
When eager execution is enabled, `learning_rate` can be a callable that
takes no arguments and returns the actual value to use. This can be useful
for changing these values across different invocations of optimizer
functions.
@end_compatibility
"""
if initial_accumulator_value < 0.0:
raise ValueError('initial_accumulator_value must be non-negative: %s' %
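For reference, a minimal sketch of constructing Adagrad with the accumulator value validated above (illustrative values, assuming the `tf.keras` API):

```python
import tensorflow as tf

# initial_accumulator_value must be non-negative, as enforced above.
opt = tf.keras.optimizers.Adagrad(
    learning_rate=0.1, initial_accumulator_value=0.1, epsilon=1e-7)
var = tf.Variable(1.0)
opt.minimize(lambda: var ** 2, var_list=[var])
```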

View File

@@ -30,7 +30,7 @@ from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.optimizers.Adam')
class Adam(optimizer_v2.OptimizerV2):
"""Optimizer that implements the Adam algorithm.
r"""Optimizer that implements the Adam algorithm.
Adam optimization is a stochastic gradient descent method that is based on
adaptive estimation of first-order and second-order moments.
@@ -43,21 +43,8 @@ class Adam(optimizer_v2.OptimizerV2):
For AMSGrad see [On The Convergence Of Adam And Beyond.
Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ).
"""
_HAS_ALL_REDUCE_SUM_GRAD = True
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
name='Adam',
**kwargs):
r"""Construct a new Adam optimizer.
If amsgrad = False:
**If amsgrad = False**:
initialize $m_0$ as 1st moment vector
initialize $v_0$ as 2nd moment vector
@@ -71,7 +58,7 @@ class Adam(optimizer_v2.OptimizerV2):
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
If amsgrad = True:
**If amsgrad = True**:
initialize $m_0$ as 1st moment vector
initialize $v_0$ as 2nd moment vector
@@ -113,6 +100,19 @@ class Adam(optimizer_v2.OptimizerV2):
>>> # The first step is `-learning_rate*sign(grad)`
>>> var1.numpy()
9.9
"""
_HAS_ALL_REDUCE_SUM_GRAD = True
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
name='Adam',
**kwargs):
"""Construct a new Adam optimizer.
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
@@ -138,7 +138,6 @@ class Adam(optimizer_v2.OptimizerV2):
gradients by value, `decay` is included for backward compatibility to
allow time inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
"""
super(Adam, self).__init__(name, **kwargs)
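A short sketch covering both update rules described in the class docstring (`amsgrad=False` and `amsgrad=True`); values are illustrative:

```python
import tensorflow as tf

adam = tf.keras.optimizers.Adam(
    learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
amsgrad_adam = tf.keras.optimizers.Adam(amsgrad=True)  # tracks the max of past v_t

var = tf.Variable(10.0)
adam.minimize(lambda: var ** 2, var_list=[var])
```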

View File

@@ -37,22 +37,6 @@ class Adamax(optimizer_v2.OptimizerV2):
Default parameters follow those provided in the paper.
Adamax is sometimes superior to Adam, especially in models with embeddings.
References
see Section 7 of [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
_HAS_ALL_REDUCE_SUM_GRAD = True
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
name='Adamax',
**kwargs):
"""Construct a new Adamax optimizer.
Initialization:
```
@@ -84,6 +68,22 @@ class Adamax(optimizer_v2.OptimizerV2):
implementations which ignore momentum unless a variable slice was actually
used).
References
see Section 7 of [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
_HAS_ALL_REDUCE_SUM_GRAD = True
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
name='Adamax',
**kwargs):
"""Construct a new Adamax optimizer.
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
`tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
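To round out the Adamax excerpt, a minimal construction sketch (defaults follow the paper, per the class docstring; assuming the `tf.keras` API):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adamax(
    learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
var = tf.Variable(10.0)
opt.minimize(lambda: tf.abs(var), var_list=[var])
```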

View File

@@ -52,6 +52,9 @@ class Ftrl(optimizer_v2.OptimizerV2):
Check the documentation for the l2_shrinkage_regularization_strength
parameter for more details when shrinkage is enabled, where gradient is
replaced with gradient_with_shrinkage.
References: See
[paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
"""
def __init__(self,
@@ -100,10 +103,6 @@ class Ftrl(optimizer_v2.OptimizerV2):
Raises:
ValueError: If one of the arguments is invalid.
References
See [paper]
(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
"""
super(Ftrl, self).__init__(name, **kwargs)
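A hedged usage sketch of the FTRL optimizer; the regularization parameters below come from the standard `tf.keras.optimizers.Ftrl` signature rather than this hunk, so treat them as assumptions:

```python
import tensorflow as tf

opt = tf.keras.optimizers.Ftrl(
    learning_rate=0.05,
    learning_rate_power=-0.5,
    initial_accumulator_value=0.1,
    l1_regularization_strength=0.01,
    l2_regularization_strength=0.01)
var = tf.Variable(2.0)
opt.minimize(lambda: (var - 1.0) ** 2, var_list=[var])
```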

View File

@@ -62,16 +62,7 @@ class LearningRateSchedule(object):
@keras_export("keras.optimizers.schedules.ExponentialDecay")
class ExponentialDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses an exponential decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
decay_rate,
staircase=False,
name=None):
"""Applies exponential decay to the learning rate.
"""A LearningRateSchedule that uses an exponential decay schedule.
When training a model, it is often recommended to lower the learning rate as
the training progresses. This schedule applies an exponential decay function
@@ -115,6 +106,21 @@ class ExponentialDecay(LearningRateSchedule):
`tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
decay_rate,
staircase=False,
name=None):
"""Applies exponential decay to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
@@ -126,11 +132,6 @@ class ExponentialDecay(LearningRateSchedule):
intervals
name: String. Optional name of the operation. Defaults to
'ExponentialDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(ExponentialDecay, self).__init__()
self.initial_learning_rate = initial_learning_rate
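To make the relocated docstring concrete, a minimal sketch of the schedule as a 1-arg callable and as an optimizer's learning rate (illustrative values):

```python
import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.96, staircase=True)
print(lr_schedule(0).numpy(), lr_schedule(10000).numpy())  # 0.1, then 0.1 * 0.96
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```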
@@ -166,14 +167,7 @@ class ExponentialDecay(LearningRateSchedule):
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
class PiecewiseConstantDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a piecewise constant decay schedule."""
def __init__(
self,
boundaries,
values,
name=None):
"""Piecewise constant from boundaries and interval values.
"""A LearningRateSchedule that uses a piecewise constant decay schedule.
The function returns a 1-arg callable to compute the piecewise constant
when passed the current optimizer step. This can be useful for changing the
@@ -198,6 +192,24 @@ class PiecewiseConstantDecay(LearningRateSchedule):
deserializable using `tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as the boundary tensors.
The output of the 1-arg function that takes the `step`
is `values[0]` when `step <= boundaries[0]`,
`values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
and values[-1] when `step > boundaries[-1]`.
"""
def __init__(
self,
boundaries,
values,
name=None):
"""Piecewise constant from boundaries and interval values.
Args:
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
increasing entries, and with all elements having the same type as the
@@ -209,16 +221,6 @@ class PiecewiseConstantDecay(LearningRateSchedule):
name: A string. Optional name of the operation. Defaults to
'PiecewiseConstant'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as the boundary tensors.
The output of the 1-arg function that takes the `step`
is `values[0]` when `step <= boundaries[0]`,
`values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
and values[-1] when `step > boundaries[-1]`.
Raises:
ValueError: if the number of elements in the lists do not match.
"""
@@ -265,17 +267,7 @@ class PiecewiseConstantDecay(LearningRateSchedule):
@keras_export("keras.optimizers.schedules.PolynomialDecay")
class PolynomialDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a polynomial decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
end_learning_rate=0.0001,
power=1.0,
cycle=False,
name=None):
"""Applies a polynomial decay to the learning rate.
"""A LearningRateSchedule that uses a polynomial decay schedule.
It is commonly observed that a monotonically decreasing learning rate, whose
degree of change is carefully chosen, results in a better performing model.
@@ -339,6 +331,22 @@ class PolynomialDecay(LearningRateSchedule):
`tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
end_learning_rate=0.0001,
power=1.0,
cycle=False,
name=None):
"""Applies a polynomial decay to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
@@ -351,11 +359,6 @@ class PolynomialDecay(LearningRateSchedule):
cycle: A boolean, whether or not it should cycle beyond decay_steps.
name: String. Optional name of the operation. Defaults to
'PolynomialDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(PolynomialDecay, self).__init__()
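A minimal sketch, assuming the signature shown above (values illustrative):

```python
import tensorflow as tf

# Decay from 0.1 to 0.01 over 10000 steps along a sqrt-shaped (power=0.5) curve.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1, decay_steps=10000,
    end_learning_rate=0.01, power=0.5, cycle=False)
print(float(lr_schedule(10000)))  # 0.01 once decay_steps is reached
```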
@@ -408,16 +411,7 @@ class PolynomialDecay(LearningRateSchedule):
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
class InverseTimeDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses an inverse time decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
decay_rate,
staircase=False,
name=None):
"""Applies inverse time decay to the initial learning rate.
"""A LearningRateSchedule that uses an inverse time decay schedule.
When training a model, it is often recommended to lower the learning rate as
the training progresses. This schedule applies the inverse decay function
@@ -462,6 +456,21 @@ class InverseTimeDecay(LearningRateSchedule):
model.fit(data, labels, epochs=5)
```
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
decay_rate,
staircase=False,
name=None):
"""Applies inverse time decay to the initial learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
@@ -471,11 +480,6 @@ class InverseTimeDecay(LearningRateSchedule):
continuous, fashion.
name: String. Optional name of the operation. Defaults to
'InverseTimeDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(InverseTimeDecay, self).__init__()
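A minimal sketch of the inverse time decay schedule (illustrative values):

```python
import tensorflow as tf

# With staircase=True: lr = 0.1 / (1 + 0.5 * floor(step / 1000)).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.5, staircase=True)
print(float(lr_schedule(2000)))  # 0.1 / (1 + 0.5 * 2) = 0.05
```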
@@ -513,15 +517,7 @@ class InverseTimeDecay(LearningRateSchedule):
@keras_export("keras.experimental.CosineDecay")
class CosineDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a cosine decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
alpha=0.0,
name=None):
"""Applies cosine decay to the learning rate.
"""A LearningRateSchedule that uses a cosine decay schedule.
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
@@ -557,6 +553,20 @@ class CosineDecay(LearningRateSchedule):
deserializable using `tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
alpha=0.0,
name=None):
"""Applies cosine decay to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` Tensor or a
Python number. The initial learning rate.
@@ -565,10 +575,6 @@ class CosineDecay(LearningRateSchedule):
alpha: A scalar `float32` or `float64` Tensor or a Python number.
Minimum learning rate value as a fraction of initial_learning_rate.
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(CosineDecay, self).__init__()
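A minimal sketch of the cosine decay schedule exported as `keras.experimental.CosineDecay` above (illustrative values):

```python
import tensorflow as tf

# Anneals from 0.1 down to alpha * 0.1 over 1000 steps along a cosine curve.
lr_schedule = tf.keras.experimental.CosineDecay(
    initial_learning_rate=0.1, decay_steps=1000, alpha=0.01)
print(float(lr_schedule(1000)))  # alpha * initial_learning_rate = 0.001
```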
@@ -604,17 +610,7 @@ class CosineDecay(LearningRateSchedule):
@keras_export("keras.experimental.CosineDecayRestarts")
class CosineDecayRestarts(LearningRateSchedule):
"""A LearningRateSchedule that uses a cosine decay schedule with restarts."""
def __init__(
self,
initial_learning_rate,
first_decay_steps,
t_mul=2.0,
m_mul=1.0,
alpha=0.0,
name=None):
"""Applies cosine decay with restarts to the learning rate.
"""A LearningRateSchedule that uses a cosine decay schedule with restarts.
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
@@ -648,6 +644,22 @@ class CosineDecayRestarts(LearningRateSchedule):
deserializable using `tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
first_decay_steps,
t_mul=2.0,
m_mul=1.0,
alpha=0.0,
name=None):
"""Applies cosine decay with restarts to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
number. The initial learning rate.
@@ -660,10 +672,6 @@ class CosineDecayRestarts(LearningRateSchedule):
alpha: A scalar `float32` or `float64` Tensor or a Python number.
Minimum learning rate value as a fraction of the initial_learning_rate.
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(CosineDecayRestarts, self).__init__()
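A sketch of the restarting variant (illustrative values):

```python
import tensorflow as tf

# First period lasts 1000 steps; each restart period doubles (t_mul=2.0) and
# restarts from 90% of the previous initial rate (m_mul=0.9).
lr_schedule = tf.keras.experimental.CosineDecayRestarts(
    initial_learning_rate=0.1, first_decay_steps=1000,
    t_mul=2.0, m_mul=0.9, alpha=0.0)
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```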
@@ -728,17 +736,7 @@ class CosineDecayRestarts(LearningRateSchedule):
@keras_export("keras.experimental.LinearCosineDecay")
class LinearCosineDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a linear cosine decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
num_periods=0.5,
alpha=0.0,
beta=0.001,
name=None):
"""Applies linear cosine decay to the learning rate.
"""A LearningRateSchedule that uses a linear cosine decay schedule.
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
https://arxiv.org/abs/1709.07417
@@ -784,6 +782,22 @@ class LinearCosineDecay(LearningRateSchedule):
deserializable using `tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
num_periods=0.5,
alpha=0.0,
beta=0.001,
name=None):
"""Applies linear cosine decay to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
number. The initial learning rate.
@@ -795,10 +809,6 @@ class LinearCosineDecay(LearningRateSchedule):
beta: See computation above.
name: String. Optional name of the operation. Defaults to
'LinearCosineDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(LinearCosineDecay, self).__init__()
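A sketch of the linear cosine schedule, using the parameters documented above (illustrative values):

```python
import tensorflow as tf

lr_schedule = tf.keras.experimental.LinearCosineDecay(
    initial_learning_rate=0.1, decay_steps=1000,
    num_periods=0.5, alpha=0.0, beta=0.001)
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```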
@@ -844,19 +854,7 @@ class LinearCosineDecay(LearningRateSchedule):
@keras_export("keras.experimental.NoisyLinearCosineDecay")
class NoisyLinearCosineDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a noisy linear cosine decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_steps,
initial_variance=1.0,
variance_decay=0.55,
num_periods=0.5,
alpha=0.0,
beta=0.001,
name=None):
"""Applies noisy linear cosine decay to the learning rate.
"""A LearningRateSchedule that uses a noisy linear cosine decay schedule.
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
https://arxiv.org/abs/1709.07417
@@ -904,6 +902,24 @@ class NoisyLinearCosineDecay(LearningRateSchedule):
deserializable using `tf.keras.optimizers.schedules.serialize` and
`tf.keras.optimizers.schedules.deserialize`.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
def __init__(
self,
initial_learning_rate,
decay_steps,
initial_variance=1.0,
variance_decay=0.55,
num_periods=0.5,
alpha=0.0,
beta=0.001,
name=None):
"""Applies noisy linear cosine decay to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
number. The initial learning rate.
@@ -917,10 +933,6 @@ class NoisyLinearCosineDecay(LearningRateSchedule):
beta: See computation above.
name: String. Optional name of the operation. Defaults to
'NoisyLinearCosineDecay'.
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer
step and outputs the decayed learning rate, a scalar `Tensor` of the same
type as `initial_learning_rate`.
"""
super(NoisyLinearCosineDecay, self).__init__()
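And the noisy variant, which adds Gaussian noise whose variance shrinks with `variance_decay` (illustrative values):

```python
import tensorflow as tf

lr_schedule = tf.keras.experimental.NoisyLinearCosineDecay(
    initial_learning_rate=0.1, decay_steps=1000,
    initial_variance=1.0, variance_decay=0.55,
    num_periods=0.5, alpha=0.0, beta=0.001)
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```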

View File

@@ -65,6 +65,18 @@ class RMSprop(optimizer_v2.OptimizerV2):
\mathrm{learning\_rate} * g_t / sqrt(rms_t - mg_t^2 + \epsilon)$$
$$\theta_t = \theta_{t-1} - mom_t$$
Note that in the dense implementation of this algorithm, variables and their
corresponding accumulators (momentum, gradient moving average, square
gradient moving average) will be updated even if the gradient is zero
(i.e. accumulators will decay, momentum will be applied). The sparse
implementation (used when the gradient is an `IndexedSlices` object,
typically because of `tf.gather` or an embedding lookup in the forward pass)
will not update variable slices or their accumulators unless those slices
were used in the forward pass (nor is there an "eventual" correction to
account for these omitted updates). This leads to more efficient updates for
large embedding lookup tables (where most of the slices are not accessed in
a particular graph execution), but differs from the published algorithm.
Usage:
>>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
@@ -91,18 +103,6 @@ class RMSprop(optimizer_v2.OptimizerV2):
**kwargs):
"""Construct a new RMSprop optimizer.
Note that in the dense implementation of this algorithm, variables and their
corresponding accumulators (momentum, gradient moving average, square
gradient moving average) will be updated even if the gradient is zero
(i.e. accumulators will decay, momentum will be applied). The sparse
implementation (used when the gradient is an `IndexedSlices` object,
typically because of `tf.gather` or an embedding lookup in the forward pass)
will not update variable slices or their accumulators unless those slices
were used in the forward pass (nor is there an "eventual" correction to
account for these omitted updates). This leads to more efficient updates for
large embedding lookup tables (where most of the slices are not accessed in
a particular graph execution), but differs from the published algorithm.
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
`tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable