Improve activation functions docstrings.

PiperOrigin-RevId: 304248027
Change-Id: I55f0a21d09d7169285a15723bdbaf952ab9b78ae
This commit is contained in:
Francois Chollet 2020-04-01 13:12:56 -07:00 committed by TensorFlower Gardener
parent fb5e9688cb
commit b2ce99d186

View File

@ -53,7 +53,9 @@ def softmax(x, axis=-1):
layer of a classification network because the result could be interpreted as
a probability distribution.
The softmax of each vector x is calculated by `exp(x)/tf.reduce_sum(exp(x))`.
The softmax of each vector x is computed as
`exp(x) / tf.reduce_sum(exp(x))`.
The input values in are the log-odds of the resulting probability.
Arguments:
@ -92,8 +94,7 @@ def elu(x, alpha=1.0):
`alpha * (exp(x)-1)` if `x < 0`.
Reference:
- [Fast and Accurate Deep Network Learning by Exponential
Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
- [Clevert et al. 2016](https://arxiv.org/abs/1511.07289)
"""
return K.elu(x, alpha)
@ -102,40 +103,36 @@ def elu(x, alpha=1.0):
def selu(x):
"""Scaled Exponential Linear Unit (SELU).
The Scaled Exponential Linear Unit (SELU) activation function is:
`scale * x` if `x > 0` and `scale * alpha * (exp(x) - 1)` if `x < 0`
The Scaled Exponential Linear Unit (SELU) activation function is defined as:
- `if x > 0: return scale * x`
- `if x < 0: return scale * alpha * (exp(x) - 1)`
where `alpha` and `scale` are pre-defined constants
(`alpha = 1.67326324`
and `scale = 1.05070098`).
The SELU activation function multiplies `scale` > 1 with the
`[elu](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/activations/elu)`
(Exponential Linear Unit (ELU)) to ensure a slope larger than one
for positive net inputs.
(`alpha=1.67326324` and `scale=1.05070098`).
Basically, the SELU activation function multiplies `scale` (> 1) with the
output of the `tf.keras.activations.elu` function to ensure a slope larger
than one for positive inputs.
The values of `alpha` and `scale` are
chosen so that the mean and variance of the inputs are preserved
between two consecutive layers as long as the weights are initialized
correctly (see [`lecun_normal` initialization]
(https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal))
and the number of inputs is "large enough"
(see references for more information).
![]https://cdn-images-1.medium.com/max/1600/1*m0e8lZU_Zrkh4ESfQkY2Pw.png
(Courtesy: Blog on Towards DataScience at
https://towardsdatascience.com/selu-make-fnns-great-again-snn-8d61526802a9)
correctly (see `tf.keras.initializers.LecunNormal` initializer)
and the number of input units is "large enough"
(see reference paper for more information).
Example Usage:
>>> n_classes = 10 #10-class problem
>>> from tensorflow.python.keras.layers import Dense
>>> num_classes = 10 # 10-class problem
>>> model = tf.keras.Sequential()
>>> model.add(Dense(64, kernel_initializer='lecun_normal',
... activation='selu', input_shape=(28, 28, 1)))
>>> model.add(Dense(32, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(Dense(16, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(Dense(n_classes, activation='softmax'))
>>> model.add(tf.keras.layers.Dense(64, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(tf.keras.layers.Dense(32, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(tf.keras.layers.Dense(16, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
Arguments:
x: A tensor or variable to compute the activation function for.
@ -143,22 +140,21 @@ def selu(x):
Returns:
The scaled exponential unit activation: `scale * elu(x, alpha)`.
# Note
- To be used together with the initialization "[lecun_normal]
(https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal)".
- To be used together with the dropout variant "[AlphaDropout]
(https://www.tensorflow.org/api_docs/python/tf/keras/layers/AlphaDropout)".
Notes:
- To be used together with the
`tf.keras.initializers.LecunNormal` initializer.
- To be used together with the dropout variant
`tf.keras.layers.AlphaDropout` (not regular dropout).
References:
[Self-Normalizing Neural Networks (Klambauer et al, 2017)]
(https://arxiv.org/abs/1706.02515)
- [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
"""
return nn.selu(x)
@keras_export('keras.activations.softplus')
def softplus(x):
"""Softplus activation function.
"""Softplus activation function, `softplus(x) = log(exp(x) + 1)`.
Arguments:
x: Input tensor.
@ -171,7 +167,7 @@ def softplus(x):
@keras_export('keras.activations.softsign')
def softsign(x):
"""Softsign activation function.
"""Softsign activation function, `softsign(x) = x / (abs(x) + 1)`.
Arguments:
x: Input tensor.
@ -190,7 +186,10 @@ def swish(x):
x: Input tensor.
Returns:
The swish activation applied to `x`.
The swish activation applied to `x` (see reference paper for details).
Reference:
- [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)
"""
return nn.swish(x)
@ -244,8 +243,7 @@ def tanh(x):
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.tanh(a)
>>> b.numpy()
array([-0.9950547, -0.7615942, 0. , 0.7615942, 0.9950547],
dtype=float32)
array([-0.9950547, -0.7615942, 0., 0.7615942, 0.9950547], dtype=float32)
Arguments:
x: Input tensor.
@ -259,12 +257,10 @@ def tanh(x):
@keras_export('keras.activations.sigmoid')
def sigmoid(x):
"""Sigmoid activation function.
"""Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`.
Applies the sigmoid activation function. The sigmoid function is defined as
1 divided by (1 + exp(-x)). It's curve is like an "S" and is like a smoothed
version of the Heaviside (Unit Step Function) function. For small values
(<-5) the sigmoid returns a value close to zero and for larger values (>5)
Applies the sigmoid activation function. For small values (<-5),
`sigmoid` returns a value close to zero, and for large values (>5)
the result of the function gets close to 1.
Sigmoid is equivalent to a 2-element Softmax, where the second element is
@ -274,15 +270,14 @@ def sigmoid(x):
>>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
>>> b = tf.keras.activations.sigmoid(a)
>>> b.numpy() >= 0.0
>>> b.numpy() >= 0
array([ True, True, True, True, True])
Arguments:
x: Input tensor.
Returns:
Tensor with the sigmoid activation: `(1.0 / (1.0 + exp(-x)))`.
Tensor will be of same shape and dtype of input `x`.
Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
"""
return nn.sigmoid(x)
@ -296,15 +291,13 @@ def exponential(x):
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.exponential(a)
>>> b.numpy()
array([ 0.04978707, 0.36787945, 1. , 2.7182817 , 20.085537 ],
dtype=float32)
array([0.04978707, 0.36787945, 1., 2.7182817 , 20.085537], dtype=float32)
Arguments:
x: Input tensor.
Returns:
Tensor with exponential activation: `exp(x)`. Tensor will be of same
shape and dtype of input `x`.
Tensor with exponential activation: `exp(x)`.
"""
return math_ops.exp(x)
@ -313,7 +306,7 @@ def exponential(x):
def hard_sigmoid(x):
"""Hard sigmoid activation function.
Faster to compute than sigmoid activation.
A faster approximation of the sigmoid activation.
For example:
@ -326,18 +319,18 @@ def hard_sigmoid(x):
x: Input tensor.
Returns:
The hard sigmoid activation:
The hard sigmoid activation, defined as:
- `0` if `x < -2.5`
- `1` if `x > 2.5`
- `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
- `if x < -2.5: return 0`
- `if x > 2.5: return 1`
- `if -2.5 <= x <= 2.5: return 0.2 * x + 0.5`
"""
return K.hard_sigmoid(x)
@keras_export('keras.activations.linear')
def linear(x):
"""Linear activation function.
"""Linear activation function (pass-through).
For example:
@ -350,17 +343,17 @@ def linear(x):
x: Input tensor.
Returns:
the input unmodified.
The input, unmodified.
"""
return x
@keras_export('keras.activations.serialize')
def serialize(activation):
"""Returns name attribute (`__name__`) of function.
"""Returns the string identifier of an activation function.
Arguments:
activation : Function
activation : Function object.
Returns:
String denoting the name attribute of the input function
@ -387,13 +380,13 @@ def serialize(activation):
@keras_export('keras.activations.deserialize')
def deserialize(name, custom_objects=None):
"""Returns activation function denoted by input string.
"""Returns activation function given a string identifier.
Arguments:
x : String
x : String identifier.
Returns:
TensorFlow Activation function denoted by input string.
Corresponding activation function.
For example:
@ -408,8 +401,8 @@ def deserialize(name, custom_objects=None):
Args:
name: The name of the activation function.
custom_objects: A {name:value} dictionary for activations not build into
keras.
custom_objects: Optional `{function_name: function_obj}`
dictionary listing user-provided activation functions.
Raises:
ValueError: `Unknown activation function` if the input string does not
@ -430,9 +423,7 @@ def get(identifier):
identifier: Function or string
Returns:
Activation function denoted by input:
- `Linear activation function` if input is `None`.
- Function corresponding to the input string or input function.
Function corresponding to the input string or input function.
For example: