Move all the docstrings in init to the class docstrings except Args.
PiperOrigin-RevId: 301211245 Change-Id: I1007dbe7536bc8439c95fb6789a73c36dfc4224b
This commit is contained in:
parent
7680958e81
commit
a3d80814af
@ -52,10 +52,23 @@ class Adadelta(optimizer_v2.OptimizerV2):
|
||||
$$E[\Delta x^2]_t := \rho * E[\Delta x^2]_{t-1} + (1 - \rho) * \Delta x_t^2$$
|
||||
$$x_t := x_{t-1} + \Delta x_{t}$$
|
||||
|
||||
Adadelta is a more robust extension of Adagrad that adapts learning rates
|
||||
based on a moving window of gradient updates, instead of accumulating all
|
||||
past gradients. This way, Adadelta continues learning even when many updates
|
||||
have been done. Compared to Adagrad, in the original version of Adadelta you
|
||||
don't have to set an initial learning rate. In this version, initial
|
||||
learning rate can be set, as in most other Keras optimizers.
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
|
||||
each be a callable that takes no arguments and returns the actual value to
|
||||
use. This can be useful for changing these values across different
|
||||
invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
|
||||
References
|
||||
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
|
||||
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
|
||||
|
||||
"""
|
||||
|
||||
_HAS_ALL_REDUCE_SUM_GRAD = True
|
||||
@ -68,13 +81,6 @@ class Adadelta(optimizer_v2.OptimizerV2):
|
||||
**kwargs):
|
||||
"""Construct a new Adadelta optimizer.
|
||||
|
||||
Adadelta is a more robust extension of Adagrad that adapts learning rates
|
||||
based on a moving window of gradient updates, instead of accumulating all
|
||||
past gradients. This way, Adadelta continues learning even when many updates
|
||||
have been done. Compared to Adagrad, in the original version of Adadelta you
|
||||
don't have to set an initial learning rate. In this version, initial
|
||||
learning rate can be set, as in most other Keras optimizers.
|
||||
|
||||
Args:
|
||||
learning_rate: A `Tensor`, floating point value, or a schedule that is a
|
||||
`tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
|
||||
@ -89,13 +95,6 @@ class Adadelta(optimizer_v2.OptimizerV2):
|
||||
gradients by value, `decay` is included for backward compatibility to
|
||||
allow time inverse decay of learning rate. `lr` is included for backward
|
||||
compatibility, recommended to use `learning_rate` instead.
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
|
||||
each be a callable that takes no arguments and returns the actual value to
|
||||
use. This can be useful for changing these values across different
|
||||
invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
super(Adadelta, self).__init__(name, **kwargs)
|
||||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||
|
@ -47,6 +47,13 @@ class Adagrad(optimizer_v2.OptimizerV2):
|
||||
$$accum_{g_t} := accum_{g_{t-1}} + g^2$$
|
||||
$$\theta_t := \theta_{t-1} - lr * g / (\sqrt{accum_{g_t}} + \epsilon)$$
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate` can be a callable that
|
||||
takes no arguments and returns the actual value to use. This can be useful
|
||||
for changing these values across different invocations of optimizer
|
||||
functions.
|
||||
@end_compatibility
|
||||
|
||||
References:
|
||||
|
||||
* [Paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
|
||||
@ -80,13 +87,6 @@ class Adagrad(optimizer_v2.OptimizerV2):
|
||||
|
||||
Raises:
|
||||
ValueError: If the `initial_accumulator_value` or `epsilon` is invalid.
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate` can be a callable that
|
||||
takes no arguments and returns the actual value to use. This can be useful
|
||||
for changing these values across different invocations of optimizer
|
||||
functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
if initial_accumulator_value < 0.0:
|
||||
raise ValueError('initial_accumulator_value must be non-negative: %s' %
|
||||
|
@ -30,7 +30,7 @@ from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@keras_export('keras.optimizers.Adam')
|
||||
class Adam(optimizer_v2.OptimizerV2):
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
r"""Optimizer that implements the Adam algorithm.
|
||||
|
||||
Adam optimization is a stochastic gradient descent method that is based on
|
||||
adaptive estimation of first-order and second-order moments.
|
||||
@ -43,6 +43,63 @@ class Adam(optimizer_v2.OptimizerV2):
|
||||
|
||||
For AMSGrad see [On The Convergence Of Adam And Beyond.
|
||||
Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ).
|
||||
|
||||
**If amsgrad = False**:
|
||||
|
||||
initialize $m_0$ as 1st moment vector
|
||||
initialize $v_0$ as 2nd moment vector
|
||||
|
||||
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||
described at the end of section 2 of the paper:
|
||||
|
||||
$$lr_t = \mathrm{learning\_rate} *
|
||||
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
|
||||
**If amsgrad = True**:
|
||||
|
||||
initialize $m_0$ as 1st moment vector
|
||||
initialize $v_0$ as 2nd moment vector
|
||||
initialize $\hat{v}_0$ as 2nd moment vector
|
||||
|
||||
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||
described at the end of section 2 of the paper:
|
||||
|
||||
$$lr_t = \mathrm{learning\_rate} *
|
||||
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||
|
||||
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||
$$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$
|
||||
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$
|
||||
|
||||
The default value of 1e-7 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
Usage:
|
||||
|
||||
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
|
||||
>>> var1 = tf.Variable(10.0)
|
||||
>>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1
|
||||
>>> step_count = opt.minimize(loss, [var1]).numpy()
|
||||
>>> # The first step is `-learning_rate*sign(grad)`
|
||||
>>> var1.numpy()
|
||||
9.9
|
||||
"""
|
||||
|
||||
_HAS_ALL_REDUCE_SUM_GRAD = True
|
||||
@ -55,64 +112,7 @@ class Adam(optimizer_v2.OptimizerV2):
|
||||
amsgrad=False,
|
||||
name='Adam',
|
||||
**kwargs):
|
||||
r"""Construct a new Adam optimizer.
|
||||
|
||||
If amsgrad = False:
|
||||
|
||||
initialize $m_0$ as 1st moment vector
|
||||
initialize $v_0$ as 2nd moment vector
|
||||
|
||||
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||
described at the end of section 2 of the paper:
|
||||
|
||||
$$lr_t = \mathrm{learning\_rate} *
|
||||
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
|
||||
If amsgrad = True:
|
||||
|
||||
initialize $m_0$ as 1st moment vector
|
||||
initialize $v_0$ as 2nd moment vector
|
||||
initialize $\hat{v}_0$ as 2nd moment vector
|
||||
|
||||
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||
described at the end of section 2 of the paper:
|
||||
|
||||
$$lr_t = \mathrm{learning\_rate} *
|
||||
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||
|
||||
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||
$$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$
|
||||
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$
|
||||
|
||||
The default value of 1e-7 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
Usage:
|
||||
|
||||
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
|
||||
>>> var1 = tf.Variable(10.0)
|
||||
>>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1
|
||||
>>> step_count = opt.minimize(loss, [var1]).numpy()
|
||||
>>> # The first step is `-learning_rate*sign(grad)`
|
||||
>>> var1.numpy()
|
||||
9.9
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
Args:
|
||||
learning_rate: A `Tensor`, floating point value, or a schedule that is a
|
||||
@ -138,7 +138,6 @@ class Adam(optimizer_v2.OptimizerV2):
|
||||
gradients by value, `decay` is included for backward compatibility to
|
||||
allow time inverse decay of learning rate. `lr` is included for backward
|
||||
compatibility, recommended to use `learning_rate` instead.
|
||||
|
||||
"""
|
||||
|
||||
super(Adam, self).__init__(name, **kwargs)
|
||||
|
@ -37,6 +37,37 @@ class Adamax(optimizer_v2.OptimizerV2):
|
||||
Default parameters follow those provided in the paper.
|
||||
Adamax is sometimes superior to adam, specially in models with embeddings.
|
||||
|
||||
Initialization:
|
||||
|
||||
```
|
||||
m_0 <- 0 (Initialize initial 1st moment vector)
|
||||
v_0 <- 0 (Initialize the exponentially weighted infinity norm)
|
||||
t <- 0 (Initialize timestep)
|
||||
```
|
||||
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section 7.1 of the paper:
|
||||
|
||||
```
|
||||
t <- t + 1
|
||||
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
v_t <- max(beta2 * v_{t-1}, abs(g))
|
||||
variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
|
||||
```
|
||||
|
||||
Similar to AdamOptimizer, the epsilon is added for numerical stability
|
||||
(especially to get rid of division by zero when v_t = 0).
|
||||
|
||||
Contrast to AdamOptimizer, the sparse implementation of this algorithm
|
||||
(used when the gradient is an IndexedSlices object, typically because of
|
||||
`tf.gather` or an embedding lookup in the forward pass) only updates
|
||||
variable slices and corresponding `m_t`, `v_t` terms when that part of
|
||||
the variable was used in the forward pass. This means that the sparse
|
||||
behavior is contrast to the dense behavior (similar to some momentum
|
||||
implementations which ignore momentum unless a variable slice was actually
|
||||
used).
|
||||
|
||||
References
|
||||
see Section 7 of [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
@ -53,37 +84,6 @@ class Adamax(optimizer_v2.OptimizerV2):
|
||||
**kwargs):
|
||||
"""Construct a new Adamax optimizer.
|
||||
|
||||
Initialization:
|
||||
|
||||
```
|
||||
m_0 <- 0 (Initialize initial 1st moment vector)
|
||||
v_0 <- 0 (Initialize the exponentially weighted infinity norm)
|
||||
t <- 0 (Initialize timestep)
|
||||
```
|
||||
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section 7.1 of the paper:
|
||||
|
||||
```
|
||||
t <- t + 1
|
||||
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
v_t <- max(beta2 * v_{t-1}, abs(g))
|
||||
variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
|
||||
```
|
||||
|
||||
Similar to AdamOptimizer, the epsilon is added for numerical stability
|
||||
(especially to get rid of division by zero when v_t = 0).
|
||||
|
||||
Contrast to AdamOptimizer, the sparse implementation of this algorithm
|
||||
(used when the gradient is an IndexedSlices object, typically because of
|
||||
`tf.gather` or an embedding lookup in the forward pass) only updates
|
||||
variable slices and corresponding `m_t`, `v_t` terms when that part of
|
||||
the variable was used in the forward pass. This means that the sparse
|
||||
behavior is contrast to the dense behavior (similar to some momentum
|
||||
implementations which ignore momentum unless a variable slice was actually
|
||||
used).
|
||||
|
||||
Args:
|
||||
learning_rate: A `Tensor`, floating point value, or a schedule that is a
|
||||
`tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
|
||||
|
@ -52,6 +52,9 @@ class Ftrl(optimizer_v2.OptimizerV2):
|
||||
Check the documentation for the l2_shrinkage_regularization_strength
|
||||
parameter for more details when shrinkage is enabled, where gradient is
|
||||
replaced with gradient_with_shrinkage.
|
||||
|
||||
References: See
|
||||
[paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -100,10 +103,6 @@ class Ftrl(optimizer_v2.OptimizerV2):
|
||||
|
||||
Raises:
|
||||
ValueError: If one of the arguments is invalid.
|
||||
|
||||
References
|
||||
See [paper]
|
||||
(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
|
||||
"""
|
||||
super(Ftrl, self).__init__(name, **kwargs)
|
||||
|
||||
|
@ -62,7 +62,55 @@ class LearningRateSchedule(object):
|
||||
|
||||
@keras_export("keras.optimizers.schedules.ExponentialDecay")
|
||||
class ExponentialDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses an exponential decay schedule."""
|
||||
"""A LearningRateSchedule that uses an exponential decay schedule.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies an exponential decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate * decay_rate ^ (step / decay_steps)
|
||||
```
|
||||
|
||||
If the argument `staircase` is `True`, then `step / decay_steps` is
|
||||
an integer division and the decayed learning rate follows a
|
||||
staircase function.
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: When fitting a Keras model, decay every 100000 steps with a base
|
||||
of 0.96:
|
||||
|
||||
```python
|
||||
initial_learning_rate = 0.1
|
||||
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
|
||||
initial_learning_rate,
|
||||
decay_steps=100000,
|
||||
decay_rate=0.96,
|
||||
staircase=True)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
The learning rate schedule is also serializable and deserializable using
|
||||
`tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -73,48 +121,6 @@ class ExponentialDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies exponential decay to the learning rate.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies an exponential decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate * decay_rate ^ (step / decay_steps)
|
||||
```
|
||||
|
||||
If the argument `staircase` is `True`, then `step / decay_steps` is
|
||||
an integer division and the decayed learning rate follows a
|
||||
staircase function.
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: When fitting a Keras model, decay every 100000 steps with a base
|
||||
of 0.96:
|
||||
|
||||
```python
|
||||
initial_learning_rate = 0.1
|
||||
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
|
||||
initial_learning_rate,
|
||||
decay_steps=100000,
|
||||
decay_rate=0.96,
|
||||
staircase=True)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
The learning rate schedule is also serializable and deserializable using
|
||||
`tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
@ -126,11 +132,6 @@ class ExponentialDecay(LearningRateSchedule):
|
||||
intervals
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'ExponentialDecay'.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(ExponentialDecay, self).__init__()
|
||||
self.initial_learning_rate = initial_learning_rate
|
||||
@ -166,7 +167,41 @@ class ExponentialDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
|
||||
class PiecewiseConstantDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a piecewise constant decay schedule."""
|
||||
"""A LearningRateSchedule that uses a piecewise constant decay schedule.
|
||||
|
||||
The function returns a 1-arg callable to compute the piecewise constant
|
||||
when passed the current optimizer step. This can be useful for changing the
|
||||
learning rate value across different invocations of optimizer functions.
|
||||
|
||||
Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
|
||||
for the next 10000 steps, and 0.1 for any additional steps.
|
||||
|
||||
```python
|
||||
step = tf.Variable(0, trainable=False)
|
||||
boundaries = [100000, 110000]
|
||||
values = [1.0, 0.5, 0.1]
|
||||
learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
|
||||
boundaries, values)
|
||||
|
||||
# Later, whenever we perform an optimization step, we pass in the step.
|
||||
learning_rate = learning_rate_fn(step)
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as the boundary tensors.
|
||||
|
||||
The output of the 1-arg function that takes the `step`
|
||||
is `values[0]` when `step <= boundaries[0]`,
|
||||
`values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
|
||||
and values[-1] when `step > boundaries[-1]`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -175,29 +210,6 @@ class PiecewiseConstantDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Piecewise constant from boundaries and interval values.
|
||||
|
||||
The function returns a 1-arg callable to compute the piecewise constant
|
||||
when passed the current optimizer step. This can be useful for changing the
|
||||
learning rate value across different invocations of optimizer functions.
|
||||
|
||||
Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
|
||||
for the next 10000 steps, and 0.1 for any additional steps.
|
||||
|
||||
```python
|
||||
step = tf.Variable(0, trainable=False)
|
||||
boundaries = [100000, 110000]
|
||||
values = [1.0, 0.5, 0.1]
|
||||
learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
|
||||
boundaries, values)
|
||||
|
||||
# Later, whenever we perform an optimization step, we pass in the step.
|
||||
learning_rate = learning_rate_fn(step)
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
|
||||
increasing entries, and with all elements having the same type as the
|
||||
@ -209,16 +221,6 @@ class PiecewiseConstantDecay(LearningRateSchedule):
|
||||
name: A string. Optional name of the operation. Defaults to
|
||||
'PiecewiseConstant'.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as the boundary tensors.
|
||||
|
||||
The output of the 1-arg function that takes the `step`
|
||||
is `values[0]` when `step <= boundaries[0]`,
|
||||
`values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
|
||||
and values[-1] when `step > boundaries[-1]`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the number of elements in the lists do not match.
|
||||
"""
|
||||
@ -265,7 +267,75 @@ class PiecewiseConstantDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.optimizers.schedules.PolynomialDecay")
|
||||
class PolynomialDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a polynomial decay schedule."""
|
||||
"""A LearningRateSchedule that uses a polynomial decay schedule.
|
||||
|
||||
It is commonly observed that a monotonically decreasing learning rate, whose
|
||||
degree of change is carefully chosen, results in a better performing model.
|
||||
This schedule applies a polynomial decay function to an optimizer step,
|
||||
given a provided `initial_learning_rate`, to reach an `end_learning_rate`
|
||||
in the given `decay_steps`.
|
||||
|
||||
It requires a `step` value to compute the decayed learning rate. You
|
||||
can just pass a TensorFlow variable that you increment at each training
|
||||
step.
|
||||
|
||||
The schedule is a 1-arg callable that produces a decayed learning rate
|
||||
when passed the current optimizer step. This can be useful for changing the
|
||||
learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
return ((initial_learning_rate - end_learning_rate) *
|
||||
(1 - step / decay_steps) ^ (power)
|
||||
) + end_learning_rate
|
||||
```
|
||||
|
||||
If `cycle` is True then a multiple of `decay_steps` is used, the first one
|
||||
that is bigger than `step`.
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
decay_steps = decay_steps * ceil(step / decay_steps)
|
||||
return ((initial_learning_rate - end_learning_rate) *
|
||||
(1 - step / decay_steps) ^ (power)
|
||||
) + end_learning_rate
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
|
||||
sqrt (i.e. power=0.5):
|
||||
|
||||
```python
|
||||
...
|
||||
starter_learning_rate = 0.1
|
||||
end_learning_rate = 0.01
|
||||
decay_steps = 10000
|
||||
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
|
||||
starter_learning_rate,
|
||||
decay_steps,
|
||||
end_learning_rate,
|
||||
power=0.5)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(
|
||||
learning_rate=learning_rate_fn),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
The learning rate schedule is also serializable and deserializable using
|
||||
`tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -277,68 +347,6 @@ class PolynomialDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies a polynomial decay to the learning rate.
|
||||
|
||||
It is commonly observed that a monotonically decreasing learning rate, whose
|
||||
degree of change is carefully chosen, results in a better performing model.
|
||||
This schedule applies a polynomial decay function to an optimizer step,
|
||||
given a provided `initial_learning_rate`, to reach an `end_learning_rate`
|
||||
in the given `decay_steps`.
|
||||
|
||||
It requires a `step` value to compute the decayed learning rate. You
|
||||
can just pass a TensorFlow variable that you increment at each training
|
||||
step.
|
||||
|
||||
The schedule is a 1-arg callable that produces a decayed learning rate
|
||||
when passed the current optimizer step. This can be useful for changing the
|
||||
learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
return ((initial_learning_rate - end_learning_rate) *
|
||||
(1 - step / decay_steps) ^ (power)
|
||||
) + end_learning_rate
|
||||
```
|
||||
|
||||
If `cycle` is True then a multiple of `decay_steps` is used, the first one
|
||||
that is bigger than `step`.
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
decay_steps = decay_steps * ceil(step / decay_steps)
|
||||
return ((initial_learning_rate - end_learning_rate) *
|
||||
(1 - step / decay_steps) ^ (power)
|
||||
) + end_learning_rate
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
|
||||
sqrt (i.e. power=0.5):
|
||||
|
||||
```python
|
||||
...
|
||||
starter_learning_rate = 0.1
|
||||
end_learning_rate = 0.01
|
||||
decay_steps = 10000
|
||||
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
|
||||
starter_learning_rate,
|
||||
decay_steps,
|
||||
end_learning_rate,
|
||||
power=0.5)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(
|
||||
learning_rate=learning_rate_fn),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
The learning rate schedule is also serializable and deserializable using
|
||||
`tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
@ -351,11 +359,6 @@ class PolynomialDecay(LearningRateSchedule):
|
||||
cycle: A boolean, whether or not it should cycle beyond decay_steps.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'PolynomialDecay'.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(PolynomialDecay, self).__init__()
|
||||
|
||||
@ -408,7 +411,56 @@ class PolynomialDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
|
||||
class InverseTimeDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses an inverse time decay schedule."""
|
||||
"""A LearningRateSchedule that uses an inverse time decay schedule.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies the inverse decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate / (1 + decay_rate * step / decay_step)
|
||||
```
|
||||
|
||||
or, if `staircase` is `True`, as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate / (1 + decay_rate * floor(step / decay_step))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: Fit a Keras model when decaying 1/t with a rate of 0.5:
|
||||
|
||||
```python
|
||||
...
|
||||
initial_learning_rate = 0.1
|
||||
decay_steps = 1.0
|
||||
decay_rate = 0.5
|
||||
learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
|
||||
initial_learning_rate, decay_steps, decay_rate)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(
|
||||
learning_rate=learning_rate_fn),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -419,49 +471,6 @@ class InverseTimeDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies inverse time decay to the initial learning rate.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies the inverse decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate / (1 + decay_rate * step / decay_step)
|
||||
```
|
||||
|
||||
or, if `staircase` is `True`, as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
return initial_learning_rate / (1 + decay_rate * floor(step / decay_step))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate.
|
||||
Example: Fit a Keras model when decaying 1/t with a rate of 0.5:
|
||||
|
||||
```python
|
||||
...
|
||||
initial_learning_rate = 0.1
|
||||
decay_steps = 1.0
|
||||
decay_rate = 0.5
|
||||
learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
|
||||
initial_learning_rate, decay_steps, decay_rate)
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(
|
||||
learning_rate=learning_rate_fn),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit(data, labels, epochs=5)
|
||||
```
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
@ -471,11 +480,6 @@ class InverseTimeDecay(LearningRateSchedule):
|
||||
continuous, fashion.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'InverseTimeDecay'.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(InverseTimeDecay, self).__init__()
|
||||
|
||||
@ -513,7 +517,47 @@ class InverseTimeDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.experimental.CosineDecay")
|
||||
class CosineDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a cosine decay schedule."""
|
||||
"""A LearningRateSchedule that uses a cosine decay schedule.
|
||||
|
||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a cosine decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate, decay_steps)
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -523,40 +567,6 @@ class CosineDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies cosine decay to the learning rate.
|
||||
|
||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a cosine decay function
|
||||
to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate, decay_steps)
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` Tensor or a
|
||||
Python number. The initial learning rate.
|
||||
@ -565,10 +575,6 @@ class CosineDecay(LearningRateSchedule):
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
Minimum learning rate value as a fraction of initial_learning_rate.
|
||||
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(CosineDecay, self).__init__()
|
||||
|
||||
@ -604,7 +610,45 @@ class CosineDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.experimental.CosineDecayRestarts")
|
||||
class CosineDecayRestarts(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a cosine decay schedule with restarts."""
|
||||
"""A LearningRateSchedule that uses a cosine decay schedule with restarts.
|
||||
|
||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a cosine decay function with
|
||||
restarts to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
|
||||
The learning rate multiplier first decays
|
||||
from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
|
||||
restart is performed. Each new warm restart runs for `t_mul` times more
|
||||
steps and with `m_mul` times smaller initial learning rate.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
first_decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.CosineDecayRestarts(
|
||||
initial_learning_rate,
|
||||
first_decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -616,38 +660,6 @@ class CosineDecayRestarts(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies cosine decay with restarts to the learning rate.
|
||||
|
||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a cosine decay function with
|
||||
restarts to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
|
||||
The learning rate multiplier first decays
|
||||
from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
|
||||
restart is performed. Each new warm restart runs for `t_mul` times more
|
||||
steps and with `m_mul` times smaller initial learning rate.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
first_decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.CosineDecayRestarts(
|
||||
initial_learning_rate,
|
||||
first_decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
|
||||
number. The initial learning rate.
|
||||
@ -660,10 +672,6 @@ class CosineDecayRestarts(LearningRateSchedule):
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
Minimum learning rate value as a fraction of the initial_learning_rate.
|
||||
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(CosineDecayRestarts, self).__init__()
|
||||
|
||||
@ -728,7 +736,57 @@ class CosineDecayRestarts(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.experimental.LinearCosineDecay")
|
||||
class LinearCosineDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a linear cosine decay schedule."""
|
||||
"""A LearningRateSchedule that uses a linear cosine decay schedule.
|
||||
|
||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||
https://arxiv.org/abs/1709.07417
|
||||
|
||||
For the idea of warm starts here controlled by `num_periods`,
|
||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
Note that linear cosine decay is more aggressive than cosine decay and
|
||||
larger initial learning rates can typically be used.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a linear cosine decay
|
||||
function to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
linear_decay = (decay_steps - step) / decay_steps
|
||||
cosine_decay = 0.5 * (
|
||||
1 + cos(pi * 2 * num_periods * step / decay_steps))
|
||||
decayed = (alpha + linear_decay) * cosine_decay + beta
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.LinearCosineDecay(
|
||||
initial_learning_rate, decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -740,50 +798,6 @@ class LinearCosineDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies linear cosine decay to the learning rate.
|
||||
|
||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||
https://arxiv.org/abs/1709.07417
|
||||
|
||||
For the idea of warm starts here controlled by `num_periods`,
|
||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
Note that linear cosine decay is more aggressive than cosine decay and
|
||||
larger initial learning rates can typically be used.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a linear cosine decay
|
||||
function to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
linear_decay = (decay_steps - step) / decay_steps
|
||||
cosine_decay = 0.5 * (
|
||||
1 + cos(pi * 2 * num_periods * step / decay_steps))
|
||||
decayed = (alpha + linear_decay) * cosine_decay + beta
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.LinearCosineDecay(
|
||||
initial_learning_rate, decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
|
||||
number. The initial learning rate.
|
||||
@ -795,10 +809,6 @@ class LinearCosineDecay(LearningRateSchedule):
|
||||
beta: See computation above.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'LinearCosineDecay'.
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(LinearCosineDecay, self).__init__()
|
||||
|
||||
@ -844,7 +854,59 @@ class LinearCosineDecay(LearningRateSchedule):
|
||||
|
||||
@keras_export("keras.experimental.NoisyLinearCosineDecay")
|
||||
class NoisyLinearCosineDecay(LearningRateSchedule):
|
||||
"""A LearningRateSchedule that uses a noisy linear cosine decay schedule."""
|
||||
"""A LearningRateSchedule that uses a noisy linear cosine decay schedule.
|
||||
|
||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||
https://arxiv.org/abs/1709.07417
|
||||
|
||||
For the idea of warm starts here controlled by `num_periods`,
|
||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
Note that linear cosine decay is more aggressive than cosine decay and
|
||||
larger initial learning rates can typically be used.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a noisy linear cosine decay
|
||||
function to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
linear_decay = (decay_steps - step) / decay_steps)
|
||||
cosine_decay = 0.5 * (
|
||||
1 + cos(pi * 2 * num_periods * step / decay_steps))
|
||||
decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
where eps_t is 0-centered gaussian noise with variance
|
||||
initial_variance / (1 + global_step) ** variance_decay
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.NoisyLinearCosineDecay(
|
||||
initial_learning_rate, decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -858,52 +920,6 @@ class NoisyLinearCosineDecay(LearningRateSchedule):
|
||||
name=None):
|
||||
"""Applies noisy linear cosine decay to the learning rate.
|
||||
|
||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||
https://arxiv.org/abs/1709.07417
|
||||
|
||||
For the idea of warm starts here controlled by `num_periods`,
|
||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||
|
||||
Note that linear cosine decay is more aggressive than cosine decay and
|
||||
larger initial learning rates can typically be used.
|
||||
|
||||
When training a model, it is often recommended to lower the learning rate as
|
||||
the training progresses. This schedule applies a noisy linear cosine decay
|
||||
function to an optimizer step, given a provided initial learning rate.
|
||||
It requires a `step` value to compute the decayed learning rate. You can
|
||||
just pass a TensorFlow variable that you increment at each training step.
|
||||
|
||||
The schedule a 1-arg callable that produces a decayed learning
|
||||
rate when passed the current optimizer step. This can be useful for changing
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
It is computed as:
|
||||
|
||||
```python
|
||||
def decayed_learning_rate(step):
|
||||
step = min(step, decay_steps)
|
||||
linear_decay = (decay_steps - step) / decay_steps)
|
||||
cosine_decay = 0.5 * (
|
||||
1 + cos(pi * 2 * num_periods * step / decay_steps))
|
||||
decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
|
||||
return initial_learning_rate * decayed
|
||||
```
|
||||
where eps_t is 0-centered gaussian noise with variance
|
||||
initial_variance / (1 + global_step) ** variance_decay
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
decay_steps = 1000
|
||||
lr_decayed_fn = (
|
||||
tf.keras.experimental.NoisyLinearCosineDecay(
|
||||
initial_learning_rate, decay_steps))
|
||||
```
|
||||
|
||||
You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
|
||||
as the learning rate. The learning rate schedule is also serializable and
|
||||
deserializable using `tf.keras.optimizers.schedules.serialize` and
|
||||
`tf.keras.optimizers.schedules.deserialize`.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
|
||||
number. The initial learning rate.
|
||||
@ -917,10 +933,6 @@ class NoisyLinearCosineDecay(LearningRateSchedule):
|
||||
beta: See computation above.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'NoisyLinearCosineDecay'.
|
||||
Returns:
|
||||
A 1-arg callable learning rate schedule that takes the current optimizer
|
||||
step and outputs the decayed learning rate, a scalar `Tensor` of the same
|
||||
type as `initial_learning_rate`.
|
||||
"""
|
||||
super(NoisyLinearCosineDecay, self).__init__()
|
||||
|
||||
|
@ -65,6 +65,18 @@ class RMSprop(optimizer_v2.OptimizerV2):
|
||||
\mathrm{learning\_rate} * g_t / sqrt(rms_t - mg_t^2 + \epsilon)$$
|
||||
$$\theta_t = \theta_{t-1} - mom_t$$
|
||||
|
||||
Note that in the dense implementation of this algorithm, variables and their
|
||||
corresponding accumulators (momentum, gradient moving average, square
|
||||
gradient moving average) will be updated even if the gradient is zero
|
||||
(i.e. accumulators will decay, momentum will be applied). The sparse
|
||||
implementation (used when the gradient is an `IndexedSlices` object,
|
||||
typically because of `tf.gather` or an embedding lookup in the forward pass)
|
||||
will not update variable slices or their accumulators unless those slices
|
||||
were used in the forward pass (nor is there an "eventual" correction to
|
||||
account for these omitted updates). This leads to more efficient updates for
|
||||
large embedding lookup tables (where most of the slices are not accessed in
|
||||
a particular graph execution), but differs from the published algorithm.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
|
||||
@ -91,18 +103,6 @@ class RMSprop(optimizer_v2.OptimizerV2):
|
||||
**kwargs):
|
||||
"""Construct a new RMSprop optimizer.
|
||||
|
||||
Note that in the dense implementation of this algorithm, variables and their
|
||||
corresponding accumulators (momentum, gradient moving average, square
|
||||
gradient moving average) will be updated even if the gradient is zero
|
||||
(i.e. accumulators will decay, momentum will be applied). The sparse
|
||||
implementation (used when the gradient is an `IndexedSlices` object,
|
||||
typically because of `tf.gather` or an embedding lookup in the forward pass)
|
||||
will not update variable slices or their accumulators unless those slices
|
||||
were used in the forward pass (nor is there an "eventual" correction to
|
||||
account for these omitted updates). This leads to more efficient updates for
|
||||
large embedding lookup tables (where most of the slices are not accessed in
|
||||
a particular graph execution), but differs from the published algorithm.
|
||||
|
||||
Args:
|
||||
learning_rate: A `Tensor`, floating point value, or a schedule that is a
|
||||
`tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
|
||||
|
Loading…
x
Reference in New Issue
Block a user