Merge pull request from asmitapoddar:patch-6

PiperOrigin-RevId: 271580224
This commit is contained in:
TensorFlower Gardener 2019-09-27 13:21:39 -07:00
commit 68c22de798
2 changed files with 153 additions and 34 deletions
tensorflow
python/keras
tools/docs


@@ -42,12 +42,19 @@ _TF_ACTIVATIONS_V2 = {
@keras_export('keras.activations.softmax')
def softmax(x, axis=-1):
"""The softmax activation function transforms the outputs so that all values are in
"""Softmax converts a real vector to a vector of categorical probabilities.
range (0, 1) and sum to 1. It is often used as the activation for the last
The the elements of the output vector are in range (0, 1) and sum to 1.
Each vector is handled independently. The `axis` argument sets which axis
of the input the finction is applied along.
Softmax is often used as the activation for the last
layer of a classification network because the result could be interpreted as
a probability distribution. The softmax of x is calculated by
exp(x)/tf.reduce_sum(exp(x)).
a probability distribution.
The softmax of each vector x is calculated by `exp(x)/tf.reduce_sum(exp(x))`.
The input values in are the log-odds of the resulting probability.
Arguments:
x: Input tensor.
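For illustration (not part of this patch, and assuming a TF 2.x eager environment), the per-vector behavior described above can be sketched as:

```python
import tensorflow as tf

# Two independent rows of logits; softmax is applied along the last axis.
logits = tf.constant([[1.0, 2.0, 3.0],
                      [2.0, 2.0, 2.0]])
probs = tf.keras.activations.softmax(logits, axis=-1)

# Every entry lies in (0, 1) and each row sums to 1.
print(probs.numpy())
print(tf.reduce_sum(probs, axis=-1).numpy())  # rows sum to ~1
```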
@@ -118,15 +125,17 @@ def selu(x):
https://towardsdatascience.com/selu-make-fnns-great-again-snn-8d61526802a9)
Example Usage:
-```python3
-n_classes = 10 #10-class problem
-model = models.Sequential()
-model.add(Dense(64, kernel_initializer='lecun_normal', activation='selu',
-input_shape=(28, 28, 1))))
-model.add(Dense(32, kernel_initializer='lecun_normal', activation='selu'))
-model.add(Dense(16, kernel_initializer='lecun_normal', activation='selu'))
-model.add(Dense(n_classes, activation='softmax'))
-```
+>>> n_classes = 10  # 10-class problem
+>>> from tensorflow.python.keras.layers import Dense
+>>> model = tf.keras.Sequential()
+>>> model.add(Dense(64, kernel_initializer='lecun_normal',
+...                 activation='selu', input_shape=(28, 28, 1)))
+>>> model.add(Dense(32, kernel_initializer='lecun_normal',
+...                 activation='selu'))
+>>> model.add(Dense(16, kernel_initializer='lecun_normal',
+...                 activation='selu'))
+>>> model.add(Dense(n_classes, activation='softmax'))
Arguments:
x: A tensor or variable to compute the activation function for.
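As a small numeric companion to the model example above (illustrative only, assuming TF 2.x eager mode; not part of the diff), SELU scales positive and negative inputs differently:

```python
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0])
y = tf.keras.activations.selu(x)

# Positive inputs are scaled by roughly 1.0507, while negative inputs follow
# a saturating exponential curve (selu(-1.0) is approximately -1.1113).
print(y.numpy())
```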
@@ -200,48 +209,53 @@ def relu(x, alpha=0., max_value=None, threshold=0):
@keras_export('keras.activations.tanh')
def tanh(x):
"""Hyperbolic Tangent (tanh) activation function.
"""Hyperbolic tangent activation function.
For example:
-```python
-# Constant 1-D tensor populated with value list.
-a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
-b = tf.keras.activations.tanh(a) #[-0.9950547,-0.7615942,
-0.,0.7615942,0.9950547]
-```
+>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
+>>> b = tf.keras.activations.tanh(a)
+>>> b.numpy()
+array([-0.9950547, -0.7615942, 0. , 0.7615942, 0.9950547],
+dtype=float32)
Arguments:
x: Input tensor.
Returns:
-A tensor of same shape and dtype of input `x`.
-The tanh activation: `tanh(x) = sinh(x)/cosh(x) = ((exp(x) -
-exp(-x))/(exp(x) + exp(-x)))`.
+Tensor of same shape and dtype of input `x`, with tanh activation:
+`tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
"""
return nn.tanh(x)
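The formula in the Returns section can be checked numerically; the following sketch is an illustrative addition (not part of the patch), assuming TF 2.x eager execution:

```python
import numpy as np
import tensorflow as tf

x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])
lhs = tf.keras.activations.tanh(x)
# tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
rhs = (tf.exp(x) - tf.exp(-x)) / (tf.exp(x) + tf.exp(-x))
np.testing.assert_allclose(lhs.numpy(), rhs.numpy(), rtol=1e-6, atol=1e-6)
```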
@keras_export('keras.activations.sigmoid')
def sigmoid(x):
"""Sigmoid.
"""Sigmoid activation function.
Applies the sigmoid activation function. The sigmoid function is defined as
1 divided by (1 + exp(-x)). It's curve is like an "S" and is like a smoothed
version of the Heaviside (Unit Step Function) function. For small values
(<-5) the sigmoid returns a value close to zero and for larger values (>5)
the result of the function gets close to 1.
Arguments:
x: A tensor or variable.
Returns:
A tensor.
Sigmoid activation function.
Sigmoid is equivalent to a 2-element Softmax, where the second element is
assumed to be zero.
For example:
>>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
>>> b = tf.keras.activations.sigmoid(a)
>>> b.numpy()
array([0. , 0.26894143, 0.5 , 0.7310586 , 1. ],
dtype=float32)
Arguments:
x: Input tensor.
Returns:
The sigmoid activation: `(1.0 / (1.0 + exp(-x)))`.
Tensor with the sigmoid activation: `(1.0 / (1.0 + exp(-x)))`.
Tensor will be of same shape and dtype of input `x`.
"""
return nn.sigmoid(x)
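The claim that sigmoid is a 2-element softmax with one logit fixed at zero can be verified directly; this sketch is an illustrative addition (assuming TF 2.x), not part of the diff:

```python
import numpy as np
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 2.0])

# sigmoid(x) equals the first component of softmax([x, 0]).
sig = tf.keras.activations.sigmoid(x)
pairs = tf.stack([x, tf.zeros_like(x)], axis=-1)  # shape (3, 2)
two_class = tf.keras.activations.softmax(pairs)   # softmax over the last axis
np.testing.assert_allclose(sig.numpy(), two_class[:, 0].numpy(),
                           rtol=1e-6, atol=1e-6)
```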
@@ -250,11 +264,20 @@ def sigmoid(x):
def exponential(x):
"""Exponential activation function.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.exponential(a)
>>> b.numpy()
array([ 0.04978707, 0.36787945, 1. , 2.7182817 , 20.085537 ],
dtype=float32)
Arguments:
x: Input tensor.
Returns:
-The exponential activation: `exp(x)`.
+Tensor with exponential activation: `exp(x)`. Tensor will be of same
+shape and dtype of input `x`.
"""
return math_ops.exp(x)
@@ -265,11 +288,19 @@ def hard_sigmoid(x):
Faster to compute than sigmoid activation.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.hard_sigmoid(a)
>>> b.numpy()
array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)
Arguments:
x: Input tensor.
Returns:
-Hard sigmoid activation:
+The hard sigmoid activation:
- `0` if `x < -2.5`
- `1` if `x > 2.5`
- `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
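The three branches above amount to clipping the line `0.2 * x + 0.5` to [0, 1]; an illustrative check (not part of the patch, assuming TF 2.x):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])
fast = tf.keras.activations.hard_sigmoid(x)

# The same piecewise rule written as a clipped line.
reference = np.clip(0.2 * x.numpy() + 0.5, 0.0, 1.0)
np.testing.assert_allclose(fast.numpy(), reference, rtol=1e-6, atol=1e-6)
```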
@@ -281,17 +312,46 @@ def hard_sigmoid(x):
def linear(x):
"""Linear activation function.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.linear(a)
>>> b.numpy()
array([-3., -1., 0., 1., 3.], dtype=float32)
Arguments:
x: Input tensor.
Returns:
-The linear activation: `x`.
+The input, unmodified.
"""
return x
@keras_export('keras.activations.serialize')
def serialize(activation):
"""Returns name attribute (`__name__`) of function.
Arguments:
activation: Function
Returns:
String denoting the name attribute of the input function.
For example:
>>> tf.keras.activations.serialize(tf.keras.activations.tanh)
'tanh'
>>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
'sigmoid'
>>> tf.keras.activations.serialize('abcd')
Traceback (most recent call last):
...
ValueError: ('Cannot serialize', 'abcd')
Raises:
ValueError: The input function is not a valid one.
"""
if (hasattr(activation, '__name__') and
activation.__name__ in _TF_ACTIVATIONS_V2):
return _TF_ACTIVATIONS_V2[activation.__name__]
@@ -300,6 +360,33 @@ def serialize(activation):
@keras_export('keras.activations.deserialize')
def deserialize(name, custom_objects=None):
"""Returns activation function denoted by input string.
Arguments:
name: The name of the activation function (a string).
custom_objects: A {name: value} dictionary for activations not built into Keras.
Returns:
TensorFlow activation function denoted by the input string.
For example:
>>> tf.keras.activations.deserialize('linear')
<function linear at 0x1239596a8>
>>> tf.keras.activations.deserialize('sigmoid')
<function sigmoid at 0x123959510>
>>> tf.keras.activations.deserialize('abcd')
Traceback (most recent call last):
...
ValueError: Unknown activation function:abcd
Raises:
ValueError: `Unknown activation function` if the input string does not
denote any defined TensorFlow activation function.
"""
return deserialize_keras_object(
name,
module_objects=globals(),
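`serialize` and `deserialize` are intended to round-trip; a brief usage sketch (illustrative only, assuming a TF 2.x install; not part of the diff):

```python
import tensorflow as tf

name = tf.keras.activations.serialize(tf.keras.activations.tanh)
print(name)          # -> tanh

fn = tf.keras.activations.deserialize(name)
print(fn.__name__)   # -> tanh, back to a callable activation
```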
@@ -309,6 +396,35 @@ def deserialize(name, custom_objects=None):
@keras_export('keras.activations.get')
def get(identifier):
"""Returns function.
Arguments:
identifier: Function or string
Returns:
Activation function denoted by input:
- `Linear activation function` if input is `None`.
- Function corresponding to the input string or input function.
For example:
>>> tf.keras.activations.get('softmax')
<function softmax at 0x1222a3d90>
>>> tf.keras.activations.get(tf.keras.activations.softmax)
<function softmax at 0x1222a3d90>
>>> tf.keras.activations.get(None)
<function linear at 0x1239596a8>
>>> tf.keras.activations.get(abs)
<built-in function abs>
>>> tf.keras.activations.get('abcd')
Traceback (most recent call last):
...
ValueError: Unknown activation function:abcd
Raises:
ValueError: Input is an unknown function or string, i.e., the input does
not denote any defined function.
"""
if identifier is None:
return linear
if isinstance(identifier, six.string_types):
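The lines above begin the dispatch the docstring describes: `None` maps to `linear`, strings are resolved via `deserialize`, and other callables pass through. A hedged usage sketch (illustrative, assuming TF 2.x):

```python
import tensorflow as tf

# None falls back to the linear (identity) activation.
assert tf.keras.activations.get(None) is tf.keras.activations.linear
# Strings are resolved through deserialize().
assert tf.keras.activations.get('tanh').__name__ == 'tanh'
# Callables that are not registered activations are returned unchanged.
assert tf.keras.activations.get(abs) is abs
```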


@@ -121,12 +121,15 @@ class CustomOutputChecker(doctest.OutputChecker):
This allows them to be customized before they are compared.
"""
ID_RE = re.compile(r'\bid=(\d+)\b')
ADDRESS_RE = re.compile(r'\bat 0x[0-9a-f]*?>')
def check_output(self, want, got, optionflags):
# Replace tf.Tensor's id with ellipsis (...) because a tensor's id can change
# on each execution. Users may forget to use ellipsis while writing
# examples in docstrings, so replacing the id with `...` makes it safe.
-want = re.sub(r'\bid=(\d+)\b', r'id=...', want)
+want = self.ID_RE.sub('id=...', want)
+want = self.ADDRESS_RE.sub('at ...>', want)
return doctest.OutputChecker.check_output(self, want, got, optionflags)
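The two patterns only normalize the `want` string before comparison; a standalone sketch of what they match (illustrative, independent of the doctest machinery):

```python
import re

ID_RE = re.compile(r'\bid=(\d+)\b')
ADDRESS_RE = re.compile(r'\bat 0x[0-9a-f]*?>')

# Tensor ids vary between runs, so they are collapsed to an ellipsis.
print(ID_RE.sub('id=...', '<tf.Tensor: id=42, shape=(), dtype=float32>'))
# -> <tf.Tensor: id=..., shape=(), dtype=float32>

# Memory addresses in reprs vary as well.
print(ADDRESS_RE.sub('at ...>', '<function linear at 0x1239596a8>'))
# -> <function linear at ...>
```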
_MESSAGE = textwrap.dedent("""\n