Merge pull request #24791 from facaiy:ENH/better_leaky_relu
PiperOrigin-RevId: 251168734
commit e1c98eeb8f
@@ -100,7 +100,9 @@ struct LeakyRelu {
   // activations: same shape as "features".
   void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                   T alpha, typename TTypes<T>::Tensor activations) {
-    activations.device(d) = features.cwiseMax(features * alpha);
+    // Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
+    activations.device(d) =
+        (features > static_cast<T>(0)).select(features, features * alpha);
   }
 };

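The kernel comment above is the crux of the change: cwiseMax(features, alpha * features) only matches leaky ReLU when 0 <= alpha <= 1. A minimal NumPy sketch (illustrative only, not part of the patch) of where the two formulas diverge for alpha = 10:

import numpy as np

x = np.array([-0.9, 0.7, -0.5, 0.3], dtype=np.float32)
alpha = 10.0

# Old kernel behavior: elementwise max of x and alpha * x.
old = np.maximum(x, alpha * x)       # ~[-0.9,  7. , -0.5,  3. ]  positives wrongly scaled
# New kernel behavior: keep x where positive, scale by alpha elsewhere.
new = np.where(x > 0, x, alpha * x)  # ~[-9. ,  0.7, -5. ,  0.3]  leaky ReLU semantics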
@@ -406,6 +406,20 @@ class LeakyReluTest(test.TestCase):
     self.evaluate(optimizer.minimize(loss))
     self.assertAllClose(x.read_value(), -99.9)
 
+  def testUnexpectedAlphaValue(self):
+    self.assertAllClose(
+        np.array([[-9.0, 0.7, -5.0, 0.3, -0.1], [0.1, -3.0, 0.5, -27.0, 0.9]]),
+        nn_ops.leaky_relu(
+            np.array([[-0.9, 0.7, -0.5, 0.3, -0.01],
+                      [0.1, -0.3, 0.5, -2.7, 0.9]]),
+            alpha=10))
+    self.assertAllClose(
+        np.array([[9.0, 0.7, 5.0, 0.3, 0.1], [0.1, 3.0, 0.5, 27.0, 0.9]]),
+        nn_ops.leaky_relu(
+            np.array([[-0.9, 0.7, -0.5, 0.3, -0.01],
+                      [0.1, -0.3, 0.5, -2.7, 0.9]]),
+            alpha=-10))
+
 
 class EluTest(test.TestCase):
 
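The same check at the public API level, as a usage sketch (assuming TensorFlow 2.x with eager execution; tf.nn.leaky_relu is the exported name for the wrapper the test exercises through nn_ops.leaky_relu):

import numpy as np
import tensorflow as tf

x = np.array([[-0.9, 0.7, -0.5, 0.3, -0.01],
              [0.1, -0.3, 0.5, -2.7, 0.9]], dtype=np.float32)

# With alpha outside [0, 1], positives still pass through unchanged and
# negatives are scaled by alpha, matching the expected arrays in the test.
y = tf.nn.leaky_relu(x, alpha=10.0)
np.testing.assert_allclose(
    y.numpy(),
    np.array([[-9.0, 0.7, -5.0, 0.3, -0.1],
              [0.1, -3.0, 0.5, -27.0, 0.9]], dtype=np.float32),
    rtol=1e-5)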
@@ -23,7 +23,6 @@ import numbers
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -2759,12 +2758,9 @@ def leaky_relu(features, alpha=0.2, name=None):
     features = ops.convert_to_tensor(features, name="features")
     if features.dtype.is_integer:
       features = math_ops.cast(features, dtypes.float32)
-    if compat.forward_compatible(2018, 11, 1):
-      if isinstance(alpha, np.ndarray):
-        alpha = alpha.item()
-      return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
-    alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
-    return math_ops.maximum(alpha * features, features, name=name)
+    if isinstance(alpha, np.ndarray):
+      alpha = alpha.item()
+    return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
 
 
 def _flatten_outer_dims(logits):
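With the forward-compatibility window of 2018-11-01 long past, the Python wrapper now unconditionally dispatches to the fused gen_nn_ops.leaky_relu op instead of falling back to math_ops.maximum(alpha * features, features), which has the same alpha-range problem as the old kernel. The isinstance(alpha, np.ndarray) branch remains because the generated op takes alpha as a plain float attribute, so a NumPy alpha has to be unwrapped to a Python scalar first. A small standalone sketch of that normalization (the helper name is illustrative, not TensorFlow code):

import numpy as np

def _normalize_alpha(alpha):
  # Mirrors the isinstance(alpha, np.ndarray) branch in leaky_relu():
  # a 0-d (or size-1) ndarray becomes a Python scalar; larger arrays raise.
  if isinstance(alpha, np.ndarray):
    alpha = alpha.item()
  return alpha

print(_normalize_alpha(np.array(0.2)))  # 0.2 as a Python float
print(_normalize_alpha(0.3))            # left unchanged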