Merge pull request #24791 from facaiy:ENH/better_leaky_relu

PiperOrigin-RevId: 251168734
This commit is contained in:
TensorFlower Gardener 2019-06-02 23:05:44 -07:00
commit e1c98eeb8f
3 changed files with 20 additions and 8 deletions

View File

@ -100,7 +100,9 @@ struct LeakyRelu {
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
T alpha, typename TTypes<T>::Tensor activations) {
activations.device(d) = features.cwiseMax(features * alpha);
// Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
activations.device(d) =
(features > static_cast<T>(0)).select(features, features * alpha);
}
};

View File

@ -406,6 +406,20 @@ class LeakyReluTest(test.TestCase):
self.evaluate(optimizer.minimize(loss))
self.assertAllClose(x.read_value(), -99.9)
def testUnexpectedAlphaValue(self):
self.assertAllClose(
np.array([[-9.0, 0.7, -5.0, 0.3, -0.1], [0.1, -3.0, 0.5, -27.0, 0.9]]),
nn_ops.leaky_relu(
np.array([[-0.9, 0.7, -0.5, 0.3, -0.01],
[0.1, -0.3, 0.5, -2.7, 0.9]]),
alpha=10))
self.assertAllClose(
np.array([[9.0, 0.7, 5.0, 0.3, 0.1], [0.1, 3.0, 0.5, 27.0, 0.9]]),
nn_ops.leaky_relu(
np.array([[-0.9, 0.7, -0.5, 0.3, -0.01],
[0.1, -0.3, 0.5, -2.7, 0.9]]),
alpha=-10))
class EluTest(test.TestCase):

View File

@ -23,7 +23,6 @@ import numbers
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@ -2759,12 +2758,9 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.cast(features, dtypes.float32)
if compat.forward_compatible(2018, 11, 1):
if isinstance(alpha, np.ndarray):
alpha = alpha.item()
return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features, name=name)
if isinstance(alpha, np.ndarray):
alpha = alpha.item()
return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
def _flatten_outer_dims(logits):