Merge pull request from mrTsjolder:unify_doccitations

PiperOrigin-RevId: 281399355
Change-Id: Icec56ac2937bf6abba6149243c7759c7c6e244ec
TensorFlower Gardener 2019-11-19 18:24:34 -08:00
commit ba63d9082a
29 changed files with 515 additions and 286 deletions

View File

@ -427,7 +427,7 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
"""Construct a subgraph for recursive halving-doubling all-reduce.
The recursive halving-doubling algorithm is described in
http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf
(Thakur et al., 2005).
The concept is to arrange the participating n devices in
a linear sequence where devices exchange data pairwise
@ -459,6 +459,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
Raises:
ValueError: num_devices not a power of 2, or tensor len not divisible
by 2 the proper number of times.
References:
Optimization of Collective Communication Operations in MPICH:
[Thakur et al., 2005]
(https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
"""
devices = [t.device for t in input_tensors]
input_tensors, shape = _flatten_tensors(input_tensors)
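To make the exchange pattern concrete, here is a rough NumPy simulation of the halving-doubling schedule on a list of per-device arrays; it only mimics the communication pattern (the helper name and the slice bookkeeping are invented for illustration, and this is not the TensorFlow implementation):
```
import numpy as np

def simulate_halving_doubling_allreduce(chunks, op=np.add):
  # Simulates the pairwise exchange schedule in one process; `chunks` holds
  # one equal-length 1-D array per virtual device (power-of-2 count).
  n = len(chunks)
  assert n & (n - 1) == 0 and chunks[0].size % n == 0
  work = [c.astype(float) for c in chunks]
  bounds = [(0, chunks[0].size)] * n   # slice each device is responsible for

  # Reduce-scatter: halve the owned slice each round, reduce with the partner.
  dist = n // 2
  while dist >= 1:
    new_work, new_bounds = [w.copy() for w in work], list(bounds)
    for i in range(n):
      j = i ^ dist                     # partner at the current distance
      lo, hi = bounds[i]
      mid = (lo + hi) // 2
      klo, khi = (lo, mid) if i < j else (mid, hi)
      new_work[i][klo:khi] = op(work[i][klo:khi], work[j][klo:khi])
      new_bounds[i] = (klo, khi)
    work, bounds, dist = new_work, new_bounds, dist // 2

  # All-gather: the reverse pattern doubles the fully reduced slice each round.
  dist = 1
  while dist < n:
    new_work, new_bounds = [w.copy() for w in work], list(bounds)
    for i in range(n):
      j = i ^ dist
      jlo, jhi = bounds[j]
      new_work[i][jlo:jhi] = work[j][jlo:jhi]
      new_bounds[i] = (min(bounds[i][0], jlo), max(bounds[i][1], jhi))
    work, bounds, dist = new_work, new_bounds, dist * 2
  return work
```
During reduce-scatter each round exchanges half as much data as the previous one, which is what makes the schedule bandwidth-efficient for large tensors.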

View File

@ -29,12 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=['layers.BatchNormalization'])
class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift"
Sergey Ioffe, Christian Szegedy
"""Batch Normalization layer from (Ioffe et al., 2015).
Keras APIs handle BatchNormalization updates to the moving_mean and
moving_variance as part of their `fit()` and `evaluate()` loops. However, if a
@ -49,21 +44,21 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
```
Arguments:
axis: An `int` or list of `int`, the axis or axes that should be
normalized, typically the features axis/axes. For instance, after a
`Conv2D` layer with `data_format="channels_first"`, set `axis=1`. If a
list of axes is provided, each axis in `axis` will be normalized
axis: An `int` or list of `int`, the axis or axes that should be normalized,
typically the features axis/axes. For instance, after a `Conv2D` layer
with `data_format="channels_first"`, set `axis=1`. If a list of axes is
provided, each axis in `axis` will be normalized
simultaneously. Default is `-1` which uses the last axis. Note: when
using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and
`moving_variance` variables are the same rank as the input Tensor, with
dimension size 1 in all reduced (non-axis) dimensions).
using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and
`moving_variance` variables are the same rank as the input Tensor,
with dimension size 1 in all reduced (non-axis) dimensions.
momentum: Momentum for the moving average.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor. If False, `beta`
is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can be
disabled since the scaling can be done by the next layer.
scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
next layer is linear (also e.g. `nn.relu`), this can be disabled since the
scaling can be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
moving_mean_initializer: Initializer for the moving mean.
@ -71,26 +66,26 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: An optional projection function to be applied to the `beta`
weight after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The function
must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are
not safe to use when doing asynchronous distributed training.
weight after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
take as input the unprojected variable and must return the projected
variable (which must have the same shape). Constraints are not safe to use
when doing asynchronous distributed training.
gamma_constraint: An optional projection function to be applied to the
`gamma` weight after being updated by an `Optimizer`.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
`gamma` weight after being updated by an `Optimizer`.
renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra
variables during training. The inference is the same for either value of
this parameter.
renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
scalar `Tensors` used to clip the renorm correction. The correction
`(r, d)` is used as `corrected_value = normalized_value * r + d`, with
`r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
scalar `Tensors` used to clip the renorm correction. The correction `(r,
d)` is used as `corrected_value = normalized_value * r + d`, with `r`
clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
dmax are set to inf, 0, inf, respectively.
renorm_momentum: Momentum used to update the moving means and standard
deviations with renorm. Unlike `momentum`, this affects training
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
deviations with renorm. Unlike `momentum`, this affects training and
should be neither too small (which would add noise) nor too large (which
would give stale estimates). Note that `momentum` is still applied to get
the means and variances for inference.
fused: if `None` or `True`, use a faster, fused implementation if possible.
If `False`, use the system recommended implementation.
trainable: Boolean, if `True` also add variables to the graph collection
@ -107,13 +102,23 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
example, if axis==-1,
`adjustment = lambda shape: (
tf.random.uniform(shape[-1:], 0.93, 1.07),
tf.random.uniform(shape[-1:], -0.1, 0.1))`
will scale the normalized value by up to 7% up or down, then shift the
result by up to 0.1 (with independent scaling and bias for each feature
but shared across all examples), and finally apply gamma and/or beta. If
`None`, no adjustment is applied. Cannot be specified if
virtual_batch_size is specified.
tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
value by up to 7% up or down, then shift the result by up to 0.1
(with independent scaling and bias for each feature but shared
across all examples), and finally apply gamma and/or beta. If
`None`, no adjustment is applied. Cannot be specified if
virtual_batch_size is specified.
name: A string, the name of the layer.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
Batch Renormalization - Towards Reducing Minibatch Dependence in
Batch-Normalized Models:
[Ioffe, 2017]
(http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models)
([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf))
"""
def __init__(self,
@ -197,14 +202,7 @@ def batch_normalization(inputs,
fused=None,
virtual_batch_size=None,
adjustment=None):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
"Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift"
Sergey Ioffe, Christian Szegedy
"""Functional interface for the batch normalization layer from_config(Ioffe et al., 2015).
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
@ -232,9 +230,9 @@ def batch_normalization(inputs,
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor. If False, `beta`
is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can be
disabled since the scaling can be done by the next layer.
scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
next layer is linear (also e.g. `nn.relu`), this can be disabled since the
scaling can be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
moving_mean_initializer: Initializer for the moving mean.
@ -242,37 +240,37 @@ def batch_normalization(inputs,
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: An optional projection function to be applied to the `beta`
weight after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The function
must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are
not safe to use when doing asynchronous distributed training.
weight after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
take as input the unprojected variable and must return the projected
variable (which must have the same shape). Constraints are not safe to use
when doing asynchronous distributed training.
gamma_constraint: An optional projection function to be applied to the
`gamma` weight after being updated by an `Optimizer`.
`gamma` weight after being updated by an `Optimizer`.
training: Either a Python boolean, or a TensorFlow boolean scalar tensor
(e.g. a placeholder). Whether to return the output in training mode
(normalized with statistics of the current batch) or in inference mode
(normalized with moving statistics). **NOTE**: make sure to set this
parameter correctly, or else your training/inference will not work
properly.
parameter correctly, or else your training/inference will not work
properly.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: String, the name of the layer.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
reuse: Boolean, whether to reuse the weights of a previous layer by the same
name.
renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra
variables during training. The inference is the same for either value of
this parameter.
renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
scalar `Tensors` used to clip the renorm correction. The correction
`(r, d)` is used as `corrected_value = normalized_value * r + d`, with
`r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
scalar `Tensors` used to clip the renorm correction. The correction `(r,
d)` is used as `corrected_value = normalized_value * r + d`, with `r`
clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
dmax are set to inf, 0, inf, respectively.
renorm_momentum: Momentum used to update the moving means and standard
deviations with renorm. Unlike `momentum`, this affects training
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
deviations with renorm. Unlike `momentum`, this affects training and
should be neither too small (which would add noise) nor too large (which
would give stale estimates). Note that `momentum` is still applied to get
the means and variances for inference.
fused: if `None` or `True`, use a faster, fused implementation if possible.
If `False`, use the system recommended implementation.
virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
@ -287,18 +285,29 @@ def batch_normalization(inputs,
example, if axis==-1,
`adjustment = lambda shape: (
tf.random.uniform(shape[-1:], 0.93, 1.07),
tf.random.uniform(shape[-1:], -0.1, 0.1))`
will scale the normalized value by up to 7% up or down, then shift the
result by up to 0.1 (with independent scaling and bias for each feature
but shared across all examples), and finally apply gamma and/or beta. If
`None`, no adjustment is applied. Cannot be specified if
virtual_batch_size is specified.
tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
value by up to 7% up or down, then shift the result by up to 0.1
(with independent scaling and bias for each feature but shared
across all examples), and finally apply gamma and/or beta. If
`None`, no adjustment is applied. Cannot be specified if
virtual_batch_size is specified.
Returns:
Output tensor.
Raises:
ValueError: if eager execution is enabled.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
Batch Renormalization - Towards Reducing Minibatch Dependence in
Batch-Normalized Models:
[Ioffe, 2017]
(http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models)
([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf))
"""
layer = BatchNormalization(
axis=axis,

View File

@ -260,9 +260,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
Any of the entries of `t_list` that are of type `None` are ignored.
This is the correct way to perform gradient clipping (for example, see
[Pascanu et al., 2012](http://arxiv.org/abs/1211.5063)
([pdf](http://arxiv.org/pdf/1211.5063.pdf))).
This is the correct way to perform gradient clipping (Pascanu et al., 2012).
However, it is slower than `clip_by_norm()` because all the parameters must be
ready before the clipping operation can be performed.
@ -280,6 +278,11 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
Raises:
TypeError: If `t_list` is not a sequence.
References:
On the difficulty of training Recurrent Neural Networks:
[Pascanu et al., 2012](http://proceedings.mlr.press/v28/pascanu13.html)
([pdf](http://proceedings.mlr.press/v28/pascanu13.pdf))
"""
if (not isinstance(t_list, collections_abc.Sequence) or
isinstance(t_list, six.string_types)):

View File

@ -55,12 +55,7 @@ def ctc_loss(labels,
logits=None):
"""Computes the CTC (Connectionist Temporal Classification) Loss.
This op implements the CTC loss as presented in the article:
[A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
This op implements the CTC loss as presented in (Graves et al., 2006).
Input requirements:
@ -154,6 +149,12 @@ def ctc_loss(labels,
Raises:
TypeError: if labels is not a `SparseTensor`.
References:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
with Recurrent Neural Networks:
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
"""
# The second, third, etc output tensors contain the gradients. We use it in
# _CTCLossGrad() below.
@ -607,12 +608,7 @@ def ctc_loss_v2(labels,
name=None):
"""Computes CTC (Connectionist Temporal Classification) loss.
This op implements the CTC loss as presented in the article:
[A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
This op implements the CTC loss as presented in (Graves et al., 2006).
Notes:
@ -648,6 +644,12 @@ def ctc_loss_v2(labels,
Returns:
loss: tensor of shape [batch_size], negative log probabilities.
References:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
with Recurrent Neural Networks:
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
"""
if isinstance(labels, sparse_tensor.SparseTensor):
if blank_index is None:
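As a rough usage sketch, assuming the v2 signature exported as `tf.nn.ctc_loss` in TF2; the shapes (2 examples, 50 frames, 10 classes), the zero-padded dense labels and `blank_index=0` are all made up for illustration:
```
import tensorflow as tf

logits = tf.random.normal([50, 2, 10])                         # [T, B, C], time-major
labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], tf.int32)   # dense, zero-padded
label_length = tf.constant([3, 2], tf.int32)
logit_length = tf.constant([50, 50], tf.int32)

loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                      logits_time_major=True, blank_index=0)   # shape [2]
```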
@ -699,21 +701,8 @@ def ctc_loss_dense(labels,
name=None):
"""Computes CTC (Connectionist Temporal Classification) loss.
This op implements the CTC loss as presented in the article:
[A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
Using the batched forward backward algorithm described in:
[Sim, K. C., Narayanan, A., Bagby, T., Sainath, T. N., & Bacchiani, M.
Improving the efficiency of forward-backward algorithm using batched
computation in TensorFlow.
Automatic Speech Recognition and Understanding Workshop (ASRU),
2017 IEEE (pp. 258-264).
](https://ieeexplore.ieee.org/iel7/8260578/8268903/08268944.pdf)
This op implements the CTC loss as presented in (Graves et al., 2006),
using the batched forward backward algorithm described in (Sim et al., 2017).
Notes:
Significant differences from tf.compat.v1.nn.ctc_loss:
@ -755,6 +744,16 @@ def ctc_loss_dense(labels,
Returns:
loss: tensor of shape [batch_size], negative log probabilities.
References:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
with Recurrent Neural Networks:
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
Improving the efficiency of forward-backward algorithm using batched
computation in TensorFlow:
[Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
"""
with ops.name_scope(name, "ctc_loss_dense",

View File

@ -91,10 +91,8 @@ class Beta(distribution.Distribution):
density.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
#### Examples
@ -149,6 +147,12 @@ class Beta(distribution.Distribution):
grads = tf.gradients(loss, [alpha, beta])
```
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
@deprecation.deprecated(

View File

@ -97,10 +97,8 @@ class Dirichlet(distribution.Distribution):
density.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
#### Examples
@ -155,6 +153,12 @@ class Dirichlet(distribution.Distribution):
grads = tf.gradients(loss, alpha)
```
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
@deprecation.deprecated(

View File

@ -93,10 +93,8 @@ class Gamma(distribution.Distribution):
`rate` is very large. See note in `tf.random.gamma` docstring.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
#### Examples
@ -120,6 +118,11 @@ class Gamma(distribution.Distribution):
grads = tf.gradients(loss, [concentration, rate])
```
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
@deprecation.deprecated(

View File

@ -82,10 +82,8 @@ class StudentT(distribution.Distribution):
t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
#### Examples
@ -139,6 +137,11 @@ class StudentT(distribution.Distribution):
grads = tf.gradients(loss, [df, loc, scale])
```
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
@deprecation.deprecated(

View File

@ -1712,7 +1712,9 @@ def adjust_contrast(images, contrast_factor):
@tf_export('image.adjust_gamma')
def adjust_gamma(image, gamma=1, gain=1):
"""Performs Gamma Correction on the input image.
"""Performs [Gamma Correction](http://en.wikipedia.org/wiki/Gamma_correction).
on the input image.
Also known as Power Law Transform. This function converts the
input images at first to float representation, then transforms them

View File

@ -537,8 +537,6 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
This means that the result of matrix multiplication `v = Au` has `Lth` column
given circular convolution between `h` with the `Lth` column of `u`.
See http://ee.stanford.edu/~gray/toeplitz.pdf
#### Description in terms of the frequency spectrum
There is an equivalent description in terms of the [batch] spectrum `H` and
@ -694,6 +692,11 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
References:
Toeplitz and Circulant Matrices - A Review:
[Gray, 2006](https://www.nowpublishers.com/article/Details/CIT-006)
([pdf](https://ee.stanford.edu/~gray/toeplitz.pdf))
"""
def __init__(self,

View File

@ -14,14 +14,24 @@
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.
Useful reference for derivative formulas is
An extended collection of matrix derivative results for forward and reverse
mode algorithmic differentiation by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Useful reference for derivative formulas is (Mike Giles, 2008).
A detailed derivation of formulas for backpropagating through spectral layers
(SVD and Eig) by Ionescu, Vantzos & Sminchisescu:
https://arxiv.org/pdf/1509.07838v4.pdf
Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).
References:
An extended collection of matrix derivative results for
forward and reverse mode algorithmic differentiation:
[Mike Giles, 2008]
(https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
Matrix Backpropagation for Deep Networks with Structured Layers:
[Ionescu et al., 2015]
(https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
Training Deep Networks with Structured Layers by Matrix Backpropagation:
[Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from __future__ import absolute_import
from __future__ import division
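As a concrete instance of the kind of identity collected in Giles (2008), the registered determinant gradient can be checked directly in eager mode (the matrix values are arbitrary):
```
import tensorflow as tf

a = tf.constant([[2.0, 1.0], [0.5, 3.0]])
with tf.GradientTape() as tape:
  tape.watch(a)
  d = tf.linalg.det(a)
grad = tape.gradient(d, a)
# Identity: d det(A) / dA = det(A) * inv(A)^T.
expected = d * tf.transpose(tf.linalg.inv(a))
print(tf.reduce_max(tf.abs(grad - expected)).numpy())   # ~0
```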
@ -380,7 +390,7 @@ def _MatrixSquareRootGrad(op, grad):
# Used to find Kronecker products within the Sylvester equation
def _KroneckerProduct(b1, b2):
"""Computes the Kronecker product of two batches of square matrices"""
"""Computes the Kronecker product of two batches of square matrices."""
b1_shape = array_ops.shape(b1)
b2_shape = array_ops.shape(b2)
b1_order = b1_shape[-1]

View File

@ -366,7 +366,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
"""Adds a Huber Loss term to the training procedure.
"""Adds a [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss) term to the training procedure.
For each value x in `error=labels-predictions`, the following is calculated:
@ -377,8 +377,6 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
where d is `delta`.
See: https://en.wikipedia.org/wiki/Huber_loss
`weights` acts as a coefficient for the loss. If a scalar is provided, then
the loss is simply scaled by the given value. If `weights` is a tensor of size
`[batch_size]`, then the total loss for each sample of the batch is rescaled
@ -393,8 +391,8 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
delta: `float`, the point where the huber loss function
changes from a quadratic to linear.
delta: `float`, the point where the huber loss function changes from a
quadratic to linear.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
reduction: Type of reduction to apply to loss.
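A small numeric sketch with errors of 0.5, 1.0 and 3.0 and the default `delta=1.0`, so only the last error falls on the linear branch:
```
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

labels = tf.constant([0.0, 0.0, 0.0])
predictions = tf.constant([0.5, 1.0, 3.0])
loss = tf.losses.huber_loss(labels, predictions, delta=1.0,
                            reduction=tf.losses.Reduction.NONE)
# Per-element values: 0.5*0.5**2, 0.5*1.0**2, 1.0*(3.0 - 0.5*1.0)
#                   = [0.125, 0.5, 2.5]
```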

View File

@ -748,7 +748,7 @@ def auc(labels,
epsilon = 1.0e-6
def interpolate_pr_auc(tp, fp, fn):
"""Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
"""Interpolation formula inspired by section 4 of (Davis et al., 2006).
Note here we derive & use a closed formula not present in the paper
- as follows:
@ -775,8 +775,14 @@ def auc(labels,
tp: true positive counts
fp: false positive counts
fn: false negative counts
Returns:
pr_auc: an approximation of the area under the P-R curve.
References:
The Relationship Between Precision-Recall and ROC Curves:
[Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
"""
dtp = tp[:num_thresholds - 1] - tp[1:]
p = tp + fp

View File

@ -1423,7 +1423,13 @@ def batch_normalization(x,
name: A name for this operation (optional).
Returns:
Normalized, scaled, offset tensor.
The normalized, scaled, offset tensor.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
inv = math_ops.rsqrt(variance + variance_epsilon)
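A minimal eager sketch of how this op is normally fed, with batch statistics coming from `tf.nn.moments`; the shapes and epsilon are arbitrary:
```
import tensorflow as tf

x = tf.random.normal([8, 16])
mean, variance = tf.nn.moments(x, axes=[0])
scale = tf.ones([16])
offset = tf.zeros([16])

# y = scale * (x - mean) * rsqrt(variance + eps) + offset
y = tf.nn.batch_normalization(x, mean, variance, offset, scale,
                              variance_epsilon=1e-3)
```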
@ -1448,6 +1454,7 @@ def fused_batch_norm(
name=None):
r"""Batch normalization.
See Source: [Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
(http://arxiv.org/abs/1502.03167).
@ -1472,6 +1479,12 @@ def fused_batch_norm(
Raises:
ValueError: If mean or variance is not None when is_training is True.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
x = ops.convert_to_tensor(x, name="input")
scale = ops.convert_to_tensor(scale, name="scale")
@ -1557,6 +1570,12 @@ def batch_norm_with_global_normalization(t=None,
Returns:
A batch-normalized `t`.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
t = deprecated_argument_lookup("input", input, "t", t)
m = deprecated_argument_lookup("mean", mean, "m", m)
@ -1599,6 +1618,11 @@ def batch_norm_with_global_normalization_v2(input,
Returns:
A batch-normalized `t`.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
Internal Covariate Shift:
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
return batch_norm_with_global_normalization(t=input,
m=mean,
@ -1928,12 +1952,6 @@ def nce_loss(weights,
name="nce_loss"):
"""Computes and returns the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for
unnormalized statistical
models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
Also see our [Candidate Sampling Algorithms
Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
A common use case is to use this method for training, and calculate the full
sigmoid loss for evaluation or inference. In this case, you must set
`partition_strategy="div"` for the two losses to be consistent, as in the
@ -1993,9 +2011,9 @@ def nce_loss(weights,
remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
where a sampled class equals one of the target classes. If set to
`True`, this is a "Sampled Logistic" loss instead of NCE, and we are
learning to generate log-odds instead of log probabilities. See
our [Candidate Sampling Algorithms Reference]
(https://www.tensorflow.org/extras/candidate_sampling.pdf).
learning to generate log-odds instead of log probabilities. See
our Candidate Sampling Algorithms Reference
([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
Default is False.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
@ -2004,6 +2022,12 @@ def nce_loss(weights,
Returns:
A `batch_size` 1-D tensor of per-example NCE losses.
References:
Noise-contrastive estimation - A new estimation principle for unnormalized
statistical models:
[Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
"""
logits, labels = _compute_sampled_logits(
weights=weights,
@ -2160,11 +2184,9 @@ def sampled_softmax_loss(weights,
logits=logits)
```
See our [Candidate Sampling Algorithms Reference]
(https://www.tensorflow.org/extras/candidate_sampling.pdf)
Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
See our Candidate Sampling Algorithms Reference
([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
Also see Section 3 of (Jean et al., 2014) for the math.
Args:
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
@ -2195,6 +2217,11 @@ def sampled_softmax_loss(weights,
Returns:
A `batch_size` 1-D tensor of per-example sampled softmax losses.
References:
On Using Very Large Target Vocabulary for Neural Machine Translation:
[Jean et al., 2014]
(https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
([pdf](http://aclweb.org/anthology/P15-1001))
"""
logits, labels = _compute_sampled_logits(
weights=weights,

View File

@ -1449,15 +1449,10 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the amount of computation.
For a description of atrous convolution and how it can be used for dense
feature extraction, please see: [Semantic Image Segmentation with Deep
Convolutional Nets and Fully Connected CRFs](http://arxiv.org/abs/1412.7062).
The same operation is investigated further in [Multi-Scale Context Aggregation
by Dilated Convolutions](http://arxiv.org/abs/1511.07122). Previous works
that effectively use atrous convolution in different ways are, among others,
[OverFeat: Integrated Recognition, Localization and Detection using
Convolutional Networks](http://arxiv.org/abs/1312.6229) and [Fast Image
Scanning with Deep Max-Pooling Convolutional Neural
Networks](http://arxiv.org/abs/1302.1700).
feature extraction, please see: (Chen et al., 2015). The same operation is
investigated further in (Yu et al., 2016). Previous works that effectively
use atrous convolution in different ways are, among others,
(Sermanet et al., 2014) and (Giusti et al., 2013).
Atrous convolution is also closely related to the so-called noble identities
in multi-rate signal processing.
@ -1538,6 +1533,23 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
Raises:
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.
References:
Multi-Scale Context Aggregation by Dilated Convolutions:
[Yu et al., 2016](https://arxiv.org/abs/1511.07122)
([pdf](https://arxiv.org/pdf/1511.07122.pdf))
Semantic Image Segmentation with Deep Convolutional Nets and Fully
Connected CRFs:
[Chen et al., 2015](http://arxiv.org/abs/1412.7062)
([pdf](https://arxiv.org/pdf/1412.7062))
OverFeat - Integrated Recognition, Localization and Detection using
Convolutional Networks:
[Sermanet et al., 2014](https://arxiv.org/abs/1312.6229)
([pdf](https://arxiv.org/pdf/1312.6229.pdf))
Fast Image Scanning with Deep Max-Pooling Convolutional Neural Networks:
[Giusti et al., 2013]
(https://ieeexplore.ieee.org/abstract/document/6738831)
([pdf](https://arxiv.org/pdf/1302.1700.pdf))
"""
return convolution(
input=value,
@ -1760,10 +1772,9 @@ def conv1d_transpose(
name=None):
"""The transpose of `conv1d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf),
but is really the transpose (gradient) of `conv1d` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is actually the transpose (gradient) of `conv1d`
rather than an actual deconvolution.
Args:
input: A 3-D `Tensor` of type `float` and shape
@ -1792,6 +1803,13 @@ def conv1d_transpose(
ValueError: If input/output depth does not match `filter`'s shape, if
`output_shape` is not at 3-element vector, if `padding` is other than
`'VALID'` or `'SAME'`, or if `data_format` is invalid.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "conv1d_transpose",
[input, filters, output_shape]) as name:
@ -2149,10 +2167,9 @@ def conv2d_transpose(
dilations=None):
"""The transpose of `conv2d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf),
but is really the transpose (gradient) of `conv2d` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of `conv2d`
rather than an actual deconvolution.
Args:
value: A 4-D `Tensor` of type `float` and shape
@ -2189,6 +2206,13 @@ def conv2d_transpose(
Raises:
ValueError: If input/output depth does not match `filter`'s shape, or if
padding is other than `'VALID'` or `'SAME'`.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
value = deprecated_argument_lookup("input", input, "value", value)
filter = deprecated_argument_lookup("filters", filters, "filter", filter)
@ -2217,10 +2241,9 @@ def conv2d_transpose_v2(
name=None):
"""The transpose of `conv2d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
actually the transpose (gradient) of `conv2d` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of
`conv2d` rather than an actual deconvolution.
Args:
input: A 4-D `Tensor` of type `float` and shape `[batch, height, width,
@ -2255,6 +2278,13 @@ def conv2d_transpose_v2(
Raises:
ValueError: If input/output depth does not match `filter`'s shape, or if
padding is other than `'VALID'` or `'SAME'`.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "conv2d_transpose",
[input, filter, output_shape]) as name:
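A quick shape check using the TF2 export; note the filter layout `[height, width, output_channels, in_channels]`, which differs from `conv2d` (the sizes below are arbitrary):
```
import tensorflow as tf

x = tf.random.normal([1, 8, 8, 16])            # NHWC
filters = tf.random.normal([3, 3, 32, 16])     # [h, w, out_channels, in_channels]
y = tf.nn.conv2d_transpose(x, filters,
                           output_shape=[1, 16, 16, 32],
                           strides=2, padding="SAME")
print(y.shape)                                 # (1, 16, 16, 32)
```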
@ -2285,10 +2315,9 @@ def atrous_conv2d_transpose(value,
name=None):
"""The transpose of `atrous_conv2d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf),
but is really the transpose (gradient) of `atrous_conv2d` rather than an
actual deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of
`atrous_conv2d` rather than an actual deconvolution.
Args:
value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
@ -2318,6 +2347,13 @@ def atrous_conv2d_transpose(value,
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
than one, or if the output_shape is not a tensor with 4 elements.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "atrous_conv2d_transpose",
[value, filters, output_shape]) as name:
@ -2481,10 +2517,9 @@ def conv3d_transpose(
dilations=None):
"""The transpose of `conv3d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf),
but is really the transpose (gradient) of `conv3d` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
rather than an actual deconvolution.
Args:
value: A 5-D `Tensor` of type `float` and shape
@ -2518,6 +2553,13 @@ def conv3d_transpose(
Raises:
ValueError: If input/output depth does not match `filter`'s shape, or if
padding is other than `'VALID'` or `'SAME'`.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
filter = deprecated_argument_lookup("filters", filters, "filter", filter)
value = deprecated_argument_lookup("input", input, "value", value)
@ -2543,10 +2585,9 @@ def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin
name=None):
"""The transpose of `conv3d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
actually the transpose (gradient) of `conv2d` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
rather than an actual deconvolution.
Args:
input: A 5-D `Tensor` of type `float` and shape `[batch, height, width,
@ -2577,6 +2618,13 @@ def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin
Returns:
A `Tensor` with the same type as `value`.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "conv3d_transpose",
[input, filter, output_shape]) as name:
@ -2616,10 +2664,9 @@ def conv_transpose(input, # pylint: disable=redefined-builtin
name=None):
"""The transpose of `convolution`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
actually the transpose (gradient) of `convolution` rather than an actual
deconvolution.
This operation is sometimes called "deconvolution" after
(Zeiler et al., 2010), but is really the transpose (gradient) of `convolution`
rather than an actual deconvolution.
Args:
input: An N+2 dimensional `Tensor` of shape
@ -2657,6 +2704,13 @@ def conv_transpose(input, # pylint: disable=redefined-builtin
Returns:
A `Tensor` with the same type as `value`.
References:
Deconvolutional Networks:
[Zeiler et al., 2010]
(https://ieeexplore.ieee.org/abstract/document/5539957)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "conv_transpose",
[input, filter, output_shape]) as name:
@ -2804,6 +2858,12 @@ def crelu(features, name=None, axis=-1):
Returns:
A `Tensor` with the same type as `features`.
References:
Understanding and Improving Convolutional Neural Networks via Concatenated
Rectified Linear Units:
[Shang et al., 2016](http://proceedings.mlr.press/v48/shang16)
([pdf](http://proceedings.mlr.press/v48/shang16.pdf))
"""
with ops.name_scope(name, "CRelu", [features]) as name:
features = ops.convert_to_tensor(features, name="features")
@ -2821,9 +2881,6 @@ crelu_v2.__doc__ = crelu.__doc__
def relu6(features, name=None):
"""Computes Rectified Linear 6: `min(max(features, 0), 6)`.
Source: [Convolutional Deep Belief Networks on CIFAR-10. A.
Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf)
Args:
features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
`int16`, or `int8`.
@ -2831,6 +2888,11 @@ def relu6(features, name=None):
Returns:
A `Tensor` with the same type as `features`.
References:
Convolutional Deep Belief Networks on CIFAR-10:
Krizhevsky, 2010
([pdf](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf))
"""
with ops.name_scope(name, "Relu6", [features]) as name:
features = ops.convert_to_tensor(features, name="features")
@ -2844,7 +2906,6 @@ def leaky_relu(features, alpha=0.2, name=None):
Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models.
AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013]
(https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf).
Args:
features: A `Tensor` representing preactivation values. Must be one of
the following types: `float16`, `float32`, `float64`, `int32`, `int64`.
@ -2853,6 +2914,13 @@ def leaky_relu(features, alpha=0.2, name=None):
Returns:
The activation value.
References:
Rectifier Nonlinearities Improve Neural Network Acoustic Models:
[Maas et al., 2013]
(http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.693.1422)
([pdf]
(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.693.1422&rep=rep1&type=pdf))
"""
with ops.name_scope(name, "LeakyRelu", [features, alpha]) as name:
features = ops.convert_to_tensor(features, name="features")
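A tiny numeric illustration of the two activations described above:
```
import tensorflow as tf

x = tf.constant([-8.0, -1.0, 0.0, 3.0, 9.0])
print(tf.nn.relu6(x).numpy())                   # [0.  0.  0.  3.  6.]
print(tf.nn.leaky_relu(x, alpha=0.2).numpy())   # [-1.6 -0.2  0.   3.   9. ]
```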
@ -4508,9 +4576,6 @@ def fractional_max_pool(value,
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
4. length(row_pooling_sequence) = output_row_length+1
For more details on fractional max pooling, see this paper: [Benjamin Graham,
Fractional Max-Pooling](http://arxiv.org/abs/1412.6071)
Args:
value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
pooling_ratio: A list of `floats` that has length >= 4. Pooling ratio for
@ -4521,8 +4586,7 @@ def fractional_max_pool(value,
ratio on height and width dimensions respectively.
pseudo_random: An optional `bool`. Defaults to `False`. When set to `True`,
generates the pooling sequence in a pseudorandom fashion, otherwise, in a
random fashion. Check paper [Benjamin Graham, Fractional
Max-Pooling](http://arxiv.org/abs/1412.6071) for difference between
random fashion. Check (Graham, 2015) for difference between
pseudorandom and random.
overlapping: An optional `bool`. Defaults to `False`. When set to `True`,
it means when pooling, the values at the boundary of adjacent pooling
@ -4546,6 +4610,11 @@ def fractional_max_pool(value,
`value`.
row_pooling_sequence: A `Tensor` of type `int64`.
col_pooling_sequence: A `Tensor` of type `int64`.
References:
Fractional Max-Pooling:
[Graham, 2015](https://arxiv.org/abs/1412.6071)
([pdf](https://arxiv.org/pdf/1412.6071.pdf))
"""
return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
overlapping, deterministic, seed, seed2,
@ -4587,9 +4656,6 @@ def fractional_max_pool_v2(value,
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
4. length(row_pooling_sequence) = output_row_length+1
For more details on fractional max pooling, see this paper: [Benjamin Graham,
Fractional Max-Pooling](http://arxiv.org/abs/1412.6071)
Args:
value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
pooling_ratio: An int or list of `ints` that has length `1`, `2` or `4`.
@ -4600,8 +4666,7 @@ def fractional_max_pool_v2(value,
1.73 are pooling ratio on height and width dimensions respectively.
pseudo_random: An optional `bool`. Defaults to `False`. When set to `True`,
generates the pooling sequence in a pseudorandom fashion, otherwise, in a
random fashion. Check paper [Benjamin Graham, Fractional
Max-Pooling](http://arxiv.org/abs/1412.6071) for difference between
random fashion. Check (Graham, 2015) for difference between
pseudorandom and random.
overlapping: An optional `bool`. Defaults to `False`. When set to `True`,
it means when pooling, the values at the boundary of adjacent pooling
@ -4622,6 +4687,11 @@ def fractional_max_pool_v2(value,
`value`.
row_pooling_sequence: A `Tensor` of type `int64`.
col_pooling_sequence: A `Tensor` of type `int64`.
References:
Fractional Max-Pooling:
[Graham, 2015](https://arxiv.org/abs/1412.6071)
([pdf](https://arxiv.org/pdf/1412.6071.pdf))
"""
pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio")
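A short sketch using the v2 export; the 9x9 input and the ~1.44 pooling ratio are arbitrary, and the batch and channel ratios must stay at 1:
```
import tensorflow as tf

x = tf.random.normal([1, 9, 9, 3])
output, row_seq, col_seq = tf.nn.fractional_max_pool(
    x, pooling_ratio=[1.0, 1.44, 1.44, 1.0], pseudo_random=True)
print(output.shape)   # roughly [1, 6, 6, 3]
```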
@ -4666,8 +4736,7 @@ def fractional_avg_pool(value,
ratio on height and width dimensions respectively.
pseudo_random: An optional `bool`. Defaults to `False`. When set to `True`,
generates the pooling sequence in a pseudorandom fashion, otherwise, in a
random fashion. Check paper [Benjamin Graham, Fractional
Max-Pooling](http://arxiv.org/abs/1412.6071) for difference between
random fashion. Check (Graham, 2015) for difference between
pseudorandom and random.
overlapping: An optional `bool`. Defaults to `False`. When set to `True`,
it means when pooling, the values at the boundary of adjacent pooling
@ -4691,6 +4760,11 @@ def fractional_avg_pool(value,
`value`.
row_pooling_sequence: A `Tensor` of type `int64`.
col_pooling_sequence: A `Tensor` of type `int64`.
References:
Fractional Max-Pooling:
[Graham, 2015](https://arxiv.org/abs/1412.6071)
([pdf](https://arxiv.org/pdf/1412.6071.pdf))
"""
return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
overlapping, deterministic, seed, seed2,
@ -4721,8 +4795,7 @@ def fractional_avg_pool_v2(value,
ratio on height and width dimensions respectively.
pseudo_random: An optional `bool`. Defaults to `False`. When set to `True`,
generates the pooling sequence in a pseudorandom fashion, otherwise, in a
random fashion. Check paper [Benjamin Graham, Fractional
Max-Pooling](http://arxiv.org/abs/1412.6071) for difference between
random fashion. Check (Graham, 2015) for difference between
pseudorandom and random.
overlapping: An optional `bool`. Defaults to `False`. When set to `True`,
it means when pooling, the values at the boundary of adjacent pooling
@ -4743,6 +4816,11 @@ def fractional_avg_pool_v2(value,
`value`.
row_pooling_sequence: A `Tensor` of type `int64`.
col_pooling_sequence: A `Tensor` of type `int64`.
References:
Fractional Max-Pooling:
[Graham, 2015](https://arxiv.org/abs/1412.6071)
([pdf](https://arxiv.org/pdf/1412.6071.pdf))
"""
if seed == 0:
return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,

View File

@ -36,8 +36,8 @@ def add_leading_unit_dimensions(x, num_dimensions):
def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name
"""Returns the gradient of a Gamma sample w.r.t. alpha.
The gradient is computed using implicit differentiation, see
"Implicit Reparameterization Gradients" (https://arxiv.org/abs/1805.08498).
The gradient is computed using implicit differentiation
(Figurnov et al., 2018).
Args:
op: A `RandomGamma` operation. We assume that the inputs to the operation
@ -46,7 +46,14 @@ def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name
`op.outputs[0]`.
Returns:
A `Tensor` with derivatives `dloss / dalpha`
A `Tensor` with derivatives `dloss / dalpha`.
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
shape = op.inputs[0]
alpha = op.inputs[1]

View File

@ -484,10 +484,8 @@ def random_gamma(shape,
`alpha << 1` or large values of `beta`, i.e., `beta >> 1`.
The samples are differentiable w.r.t. alpha and beta.
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
Example:
@ -533,6 +531,13 @@ def random_gamma(shape,
samples: a `Tensor` of shape
`tf.concat([shape, tf.shape(alpha + beta)], axis=0)` with values of type
`dtype`.
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
with ops.name_scope(name, "random_gamma", [shape, alpha, beta]):
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
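A small eager check of the pathwise derivative: for unit rate, `E[sample] = alpha`, so the gradient of the sample mean with respect to `alpha` should come out close to 1 (the sample count is arbitrary):
```
import tensorflow as tf

alpha = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(alpha)
  samples = tf.random.gamma([20000], alpha, seed=0)
  mean = tf.reduce_mean(samples)
print(tape.gradient(mean, alpha).numpy())   # close to 1.0
```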

View File

@ -482,9 +482,7 @@ class BasicRNNCell(LayerRNNCell):
@tf_export(v1=["nn.rnn_cell.GRUCell"])
class GRUCell(LayerRNNCell):
"""Gated Recurrent Unit cell (cf.
http://arxiv.org/abs/1406.1078).
"""Gated Recurrent Unit cell.
Note that this cell is not optimized for performance. Please use
`tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
@ -505,6 +503,13 @@ class GRUCell(LayerRNNCell):
the first input). Required when `build` is called before `call`.
**kwargs: Dict, keyword named properties for common layer attributes, like
`trainable` etc when constructing the cell from configs of get_config().
References:
Learning Phrase Representations using RNN Encoder Decoder for Statistical
Machine Translation:
[Cho et al., 2014]
(https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179)
([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf))
"""
@deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell,"
@ -638,7 +643,7 @@ class BasicLSTMCell(LayerRNNCell):
Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
The implementation is based on (Zaremba et al., 2015).
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
@ -804,20 +809,8 @@ class BasicLSTMCell(LayerRNNCell):
class LSTMCell(LayerRNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
The default non-peephole implementation is based on:
https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
"Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
https://research.google.com/pubs/archive/43905.pdf
Hasim Sak, Andrew Senior, and Francoise Beaufays.
"Long short-term memory recurrent neural network architectures for
large scale acoustic modeling." INTERSPEECH, 2014.
The default non-peephole implementation is based on (Gers et al., 1999).
The peephole implementation is based on (Sak et al., 2014).
The class uses optional peep-hole connections, optional cell clipping, and
an optional projection layer.
@ -826,6 +819,21 @@ class LSTMCell(LayerRNNCell):
`tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
`tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
better performance on CPU.
References:
Long short-term memory recurrent neural network architectures for large
scale acoustic modeling:
[Sak et al., 2014]
(https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html)
([pdf]
(https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf))
Learning to Forget - Continual Prediction with LSTM:
[Gers et al., 1999]
(http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218)
Long Short-Term Memory:
[Hochreiter et al., 1997]
(https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
([pdf](http://ml.jku.at/publications/older/3504.pdf))
"""
@deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell,"

View File

@ -29,8 +29,10 @@ from tensorflow.python.util.tf_export import tf_export
class AdadeltaOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adadelta algorithm.
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
References:
ADADELTA - An Adaptive Learning Rate Method:
[Zeiler, 2012](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,

View File

@ -32,9 +32,10 @@ from tensorflow.python.util.tf_export import tf_export
class AdagradOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adagrad algorithm.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
References:
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
[Duchi et al., 2011](http://jmlr.org/papers/v12/duchi11a.html)
([pdf](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf))
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,

View File

@ -30,8 +30,6 @@ from tensorflow.python.util.tf_export import tf_export
class AdagradDAOptimizer(optimizer.Optimizer):
"""Adagrad Dual Averaging algorithm for sparse linear models.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
This optimizer takes care of regularization of unseen features in a mini batch
by updating them when they are seen with a closed form update rule that is
equivalent to having updated them on every mini-batch.
@ -40,6 +38,11 @@ class AdagradDAOptimizer(optimizer.Optimizer):
trained model. This optimizer only guarantees sparsity for linear models. Be
careful when using AdagradDA for deep networks as it will require careful
initialization of the gradient accumulators for it to train.
References:
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
[Duchi et al., 2011](http://jmlr.org/papers/v12/duchi11a.html)
([pdf](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf))
"""
def __init__(self,

View File

@ -32,8 +32,10 @@ from tensorflow.python.util.tf_export import tf_export
class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
References:
Adam - A Method for Stochastic Optimization:
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
"""
def __init__(self,

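A minimal constructor sketch using the defaults from the paper (beta1=0.9, beta2=0.999); the toy loss is illustrative:

```python
import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 3.0)

opt = tf.compat.v1.train.AdamOptimizer(
    learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8)
train_op = opt.minimize(loss)
```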
View File

@ -29,11 +29,14 @@ from tensorflow.python.util.tf_export import tf_export
class FtrlOptimizer(optimizer.Optimizer):
"""Optimizer that implements the FTRL algorithm.
See this [paper](
https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
This version has support for both online L2 (the L2 penalty given in the paper
above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
loss function).
This version has support for both online L2 (McMahan et al., 2013) and
shrinkage-type L2, which is the addition of an L2 penalty
to the loss function.
References:
Ad Click Prediction - a View from the Trenches:
[McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
"""
def __init__(self,
@ -53,8 +56,7 @@ class FtrlOptimizer(optimizer.Optimizer):
learning_rate: A float value or a constant float `Tensor`.
learning_rate_power: A float value, must be less than or equal to zero.
Controls how the learning rate decreases during training. Use zero for
a fixed learning rate. See section 3.1 in the
[paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
a fixed learning rate. See section 3.1 in (McMahan et al., 2013).
initial_accumulator_value: The starting value for accumulators.
Only zero or positive values are allowed.
l1_regularization_strength: A float value, must be greater than or
@ -84,6 +86,11 @@ class FtrlOptimizer(optimizer.Optimizer):
Raises:
ValueError: If one of the arguments is invalid.
References:
Ad Click Prediction - a View from the Trenches:
[McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
"""
super(FtrlOptimizer, self).__init__(use_locking, name)

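To illustrate the two kinds of L2 discussed above, a hedged constructor sketch; all values are placeholders:

```python
import tensorflow as tf

w = tf.Variable([0.0, 0.0, 0.0])
loss = tf.reduce_sum(tf.square(w - [1.0, 0.0, -1.0]))

# "Online" L2 is l2_regularization_strength; "shrinkage" L2, which is added to
# the loss itself, is l2_shrinkage_regularization_strength.
opt = tf.compat.v1.train.FtrlOptimizer(
    learning_rate=0.1,
    learning_rate_power=-0.5,  # must be <= 0; 0 gives a fixed learning rate
    l1_regularization_strength=0.01,
    l2_regularization_strength=0.01,
    l2_shrinkage_regularization_strength=0.0)
train_op = opt.minimize(loss)
```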
View File

@ -455,9 +455,6 @@ def inverse_time_decay(learning_rate,
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
"""Applies cosine decay to the learning rate.
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
When training a model, it is often recommended to lower the learning rate as
the training progresses. This function applies a cosine decay function
to a provided initial learning rate. It requires a `global_step` value to
@ -495,6 +492,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
Raises:
ValueError: if `global_step` is not supplied.
References:
Stochastic Gradient Descent with Warm Restarts:
[Loshchilov et al., 2017]
(https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
([pdf](https://openreview.net/pdf?id=Skq89Scxx))
@compatibility(eager)
When eager execution is enabled, this function returns a function which in
turn returns the decayed learning rate Tensor. This can be useful for changing
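A small sketch of plugging the `cosine_decay` schedule into an optimizer; the step counts and rates are illustrative:

```python
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()

# Anneals from 0.1 down to alpha * 0.1 over 10k steps along a half cosine.
lr = tf.compat.v1.train.cosine_decay(
    learning_rate=0.1, global_step=global_step,
    decay_steps=10000, alpha=0.01)

opt = tf.compat.v1.train.GradientDescentOptimizer(lr)
# Pass global_step to opt.minimize(loss, global_step=global_step) so the
# schedule actually advances.
```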
@ -521,9 +524,6 @@ def cosine_decay_restarts(learning_rate,
name=None):
"""Applies cosine decay with restarts to the learning rate.
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
When training a model, it is often recommended to lower the learning rate as
the training progresses. This function applies a cosine decay function with
restarts to a provided initial learning rate. It requires a `global_step`
@ -564,6 +564,12 @@ def cosine_decay_restarts(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
References:
Stochastic Gradient Descent with Warm Restarts:
[Loshchilov et al., 2017]
(https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
([pdf](https://openreview.net/pdf?id=Skq89Scxx))
@compatibility(eager)
When eager execution is enabled, this function returns a function which in
turn returns the decayed learning rate Tensor. This can be useful for changing
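For `cosine_decay_restarts`, a sketch showing how the period and amplitude multipliers interact; the values are illustrative:

```python
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()

# First warm restart after 1000 steps; each later period is twice as long
# (t_mul=2.0) and restarts from 80% of the previous peak (m_mul=0.8).
lr = tf.compat.v1.train.cosine_decay_restarts(
    learning_rate=0.1, global_step=global_step,
    first_decay_steps=1000, t_mul=2.0, m_mul=0.8, alpha=0.0)
```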
@ -595,13 +601,6 @@ def linear_cosine_decay(learning_rate,
name=None):
"""Applies linear cosine decay to the learning rate.
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
https://arxiv.org/abs/1709.07417
For the idea of warm starts here controlled by `num_periods`,
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
Note that linear cosine decay is more aggressive than cosine decay and
larger initial learning rates can typically be used.
@ -647,6 +646,15 @@ def linear_cosine_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
References:
Neural Optimizer Search with Reinforcement Learning:
[Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html)
([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf))
Stochastic Gradient Descent with Warm Restarts:
[Loshchilov et al., 2017]
(https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
([pdf](https://openreview.net/pdf?id=Skq89Scxx))
@compatibility(eager)
When eager execution is enabled, this function returns a function which in
turn returns the decayed learning rate Tensor. This can be useful for changing
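A brief sketch of the `linear_cosine_decay` schedule; parameter values are illustrative:

```python
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()

# num_periods sets how many cosine oscillations fit inside decay_steps, and
# beta is the small constant floor added to the schedule.
lr = tf.compat.v1.train.linear_cosine_decay(
    learning_rate=0.1, global_step=global_step, decay_steps=10000,
    num_periods=0.5, alpha=0.0, beta=0.001)
```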
@ -680,13 +688,6 @@ def noisy_linear_cosine_decay(learning_rate,
name=None):
"""Applies noisy linear cosine decay to the learning rate.
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
https://arxiv.org/abs/1709.07417
For the idea of warm starts here controlled by `num_periods`,
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
with Warm Restarts. https://arxiv.org/abs/1608.03983
Note that linear cosine decay is more aggressive than cosine decay and
larger initial learning rates can typically be used.
@ -738,6 +739,15 @@ def noisy_linear_cosine_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
References:
Neural Optimizer Search with Reinforcement Learning:
[Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html)
([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf))
Stochastic Gradient Descent with Warm Restarts:
[Loshchilov et al., 2017]
(https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
([pdf](https://openreview.net/pdf?id=Skq89Scxx))
@compatibility(eager)
When eager execution is enabled, this function returns a function which in
turn returns the decayed learning rate Tensor. This can be useful for changing

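The noisy variant, `noisy_linear_cosine_decay`, differs only in the added, decaying Gaussian noise; a sketch with illustrative values:

```python
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()

# Same shape as linear_cosine_decay, plus zero-mean Gaussian noise whose
# variance decays roughly as initial_variance / (1 + step) ** variance_decay.
lr = tf.compat.v1.train.noisy_linear_cosine_decay(
    learning_rate=0.1, global_step=global_step, decay_steps=10000,
    initial_variance=1.0, variance_decay=0.55,
    num_periods=0.5, alpha=0.0, beta=0.001)
```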
View File

@ -54,17 +54,22 @@ class MomentumOptimizer(optimizer.Optimizer):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Momentum".
use_nesterov: If `True` use Nesterov Momentum.
See [Sutskever et al., 2013](
http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
See (Sutskever et al., 2013).
This implementation always computes gradients at the value of the
variable(s) passed to the optimizer. Using Nesterov Momentum makes the
variable(s) track the values called `theta_t + mu*v_t` in the paper.
This implementation is an approximation of the original formula, valid
for high values of momentum. It will compute the "adjusted gradient"
in NAG by assuming that the new gradient will be estimated by the
current average gradient plus the product of momentum and the change
in the average gradient.
References:
On the importance of initialization and momentum in deep learning:
[Sutskever et al., 2013]
(http://proceedings.mlr.press/v28/sutskever13.html)
([pdf](http://proceedings.mlr.press/v28/sutskever13.pdf))
@compatibility(eager)
When eager execution is enabled, `learning_rate` and `momentum` can each be
a callable that takes no arguments and returns the actual value to use. This

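A short sketch of the Nesterov switch described above; the toy loss and rates are illustrative:

```python
import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 1.0)

# With use_nesterov=True the gradient is still evaluated at the variable's
# current value; the variable then tracks theta_t + mu * v_t from the paper.
opt = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=0.01, momentum=0.9, use_nesterov=True)
train_op = opt.minimize(loss)
```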
View File

@ -46,8 +46,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
`zero_debias` optionally enables scaling by the mathematically correct
debiasing factor of
1 - decay ** num_updates
See `ADAM: A Method for Stochastic Optimization` Section 3 for more details
(https://arxiv.org/abs/1412.6980).
See Section 3 of (Kingma et al., 2015) for more details.
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
@ -72,12 +71,17 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
value: A tensor with the same shape as 'variable'.
decay: A float Tensor or float value. The moving average decay.
zero_debias: A python bool. If true, assume the variable is 0-initialized
and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
and unbias it, as in (Kingma et al., 2015). See docstring in
`_zero_debias` for more details.
name: Optional name of the returned operation.
Returns:
A tensor which if evaluated will compute and return the new moving average.
References:
Adam - A Method for Stochastic Optimization:
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
"""
with ops.name_scope(name, "AssignMovingAvg",
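A hedged sketch of calling `assign_moving_average` directly; note it lives in the internal `tensorflow.python.training.moving_averages` module shown in this file, and the decay value and stand-in statistic below are illustrative:

```python
import tensorflow as tf
from tensorflow.python.training import moving_averages  # internal module

ema_var = tf.Variable(0.0, trainable=False)  # shadow variable, starts at 0
batch_mean = tf.constant(4.0)                # stand-in for a computed statistic

# With zero_debias=True the assigned value is rescaled by
# 1 - decay ** num_updates, so early estimates are not biased towards 0.
update_op = moving_averages.assign_moving_average(
    ema_var, batch_mean, decay=0.99, zero_debias=True)
```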
@ -180,7 +184,7 @@ def _zero_debias(strategy, unbiased_var, value, decay):
All exponential moving averages initialized with Tensors are initialized to 0,
and therefore are biased to 0. Variables initialized to 0 and used as EMAs are
similarly biased. This function creates the debias updated amount according to
a scale factor, as in https://arxiv.org/abs/1412.6980.
a scale factor, as in (Kingma et al., 2015).
To demonstrate the bias that results from 0-initialization, take an EMA that
was initialized to `0` with decay `b`. After `t` timesteps of seeing the
@ -204,7 +208,14 @@ def _zero_debias(strategy, unbiased_var, value, decay):
decay: A Tensor representing `1-decay` for the EMA.
Returns:
Operation which updates unbiased_var to the debiased moving average value.
The amount that the unbiased variable should be updated. Computing this
tensor will also update the shadow variables appropriately.
References:
Adam - A Method for Stochastic Optimization:
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
"""
with variable_scope.variable_scope(
unbiased_var.name[:-len(":0")], values=[unbiased_var, value, decay]):

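The bias correction itself is easy to check numerically; a pure-Python sketch of the factor described above, using a constant signal chosen for illustration:

```python
# A 0-initialized EMA of a constant signal recovers the signal exactly once it
# is divided by the debiasing factor (1 - decay ** t), per Section 3 of
# (Kingma et al., 2015).
decay, value, ema = 0.99, 5.0, 0.0
for t in range(1, 6):
    ema = decay * ema + (1.0 - decay) * value
    debiased = ema / (1.0 - decay ** t)
    print(t, round(ema, 4), round(debiased, 4))  # debiased == 5.0 at every step
```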
View File

@ -31,7 +31,13 @@ class ProximalAdagradOptimizer(optimizer.Optimizer):
# pylint: disable=line-too-long
"""Optimizer that implements the Proximal Adagrad algorithm.
See this [paper](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf).
References:
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
[Duchi et al., 2011](http://jmlr.org/papers/v12/duchi11a.html)
([pdf](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf))
Efficient Learning using Forward-Backward Splitting:
[Duchi et al., 2009](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting)
([pdf](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf))
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,

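A constructor sketch combining the two references above (Adagrad accumulators plus a FOBOS-style proximal step); values are illustrative:

```python
import tensorflow as tf

w = tf.Variable([0.0, 0.0])
loss = tf.reduce_sum(tf.square(w - [1.0, -1.0]))

opt = tf.compat.v1.train.ProximalAdagradOptimizer(
    learning_rate=0.1,
    initial_accumulator_value=0.1,
    l1_regularization_strength=0.01,  # proximal L1 step encourages sparsity
    l2_regularization_strength=0.001)
train_op = opt.minimize(loss)
```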
View File

@ -32,7 +32,10 @@ class ProximalGradientDescentOptimizer(optimizer.Optimizer):
# pylint: disable=line-too-long
"""Optimizer that implements the proximal gradient descent algorithm.
See this [paper](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf).
References:
Efficient Learning using Forward-Backward Splitting:
[Duchi et al., 2009](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting)
([pdf](http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf))
"""
def __init__(self, learning_rate, l1_regularization_strength=0.0,

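And the plain proximal gradient descent counterpart, again with illustrative values:

```python
import tensorflow as tf

w = tf.Variable([0.0, 0.0])
loss = tf.reduce_sum(tf.square(w - [1.0, -1.0]))

# A standard gradient step followed by the forward-backward (proximal)
# shrinkage step that applies the L1/L2 penalties.
opt = tf.compat.v1.train.ProximalGradientDescentOptimizer(
    learning_rate=0.1,
    l1_regularization_strength=0.01,
    l2_regularization_strength=0.001)
train_op = opt.minimize(loss)
```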
View File

@ -52,10 +52,14 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["train.RMSPropOptimizer"])
class RMSPropOptimizer(optimizer.Optimizer):
"""Optimizer that implements the RMSProp algorithm.
"""Optimizer that implements the RMSProp algorithm (Tielemans et al.
See the
[paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
2012).
References:
Coursera slide 29:
Hinton, 2012
([pdf](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf))
"""
def __init__(self,
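Finally, a constructor sketch for RMSProp; the decay and epsilon shown are the documented defaults, the rest is illustrative:

```python
import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 2.0)

# `decay` is the moving-average coefficient for the squared gradients;
# centered=True would additionally track the gradient mean at extra cost.
opt = tf.compat.v1.train.RMSPropOptimizer(
    learning_rate=0.01, decay=0.9, momentum=0.0,
    epsilon=1e-10, centered=False)
train_op = opt.minimize(loss)
```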