Implemented selu activation #10612 (#10818)

* Implemented selu activation #10612

* fix the error in _SeluGradGrad

* update golden file for api change

* add XLA kernels for Selu and SeluGrad
Lakshay Garg 2017-07-26 10:18:01 +05:30 committed by Vijay Vasudevan
parent 80d57aeadd
commit c2ce4f68c7
21 changed files with 465 additions and 14 deletions
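For orientation, here is a minimal NumPy sketch (illustrative only, not part of the commit) of the activation and the gradient rule implemented by the kernels below; the two literals are the ones hard-coded throughout the diff.

import numpy as np

# scale and scale * alpha from Klambauer et al., "Self-Normalizing Neural
# Networks" (https://arxiv.org/abs/1706.02515), as hard-coded in the kernels.
SCALE = 1.0507009873554804934193349852946
SCALE_ALPHA = 1.7580993408473768599402175208123  # scale * alpha

def selu(x):
    # scale * x for x > 0, scale * alpha * (exp(x) - 1) otherwise
    x = np.asarray(x, dtype=np.float64)
    return np.where(x > 0, SCALE * x, SCALE_ALPHA * (np.exp(x) - 1.0))

def selu_grad(gradients, outputs):
    # Backprop rule of the SeluGrad op: it only needs the forward op's
    # outputs, because d(selu)/dx == outputs + scale * alpha on the
    # negative branch.
    outputs = np.asarray(outputs, dtype=np.float64)
    return np.where(outputs > 0, gradients * SCALE,
                    gradients * (outputs + SCALE_ALPHA))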

View File

@@ -86,6 +86,15 @@ Status EluGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);
Status SeluGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
grad_outputs->push_back(dx);
return scope.status();
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow

View File

@@ -103,5 +103,15 @@ TEST_F(NNGradTest, EluGrad) {
RunTest(x, x_init_value, y, shape);
}
TEST_F(NNGradTest, SeluGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Selu(scope_, x);
Tensor x_init_value = test::AsTensor<float>(
{-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
{5, 2});
RunTest(x, x_init_value, y, shape);
}
} // namespace
} // namespace tensorflow

View File

@@ -177,6 +177,7 @@ op { name: "MaxPoolGradWithArgmax" hide: true }
op { name: "ReluGrad" hide: true }
op { name: "Relu6Grad" hide: true }
op { name: "EluGrad" hide: true }
op { name: "SeluGrad" hide: true }
op { name: "SoftplusGrad" hide: true }
op { name: "SoftsignGrad" hide: true }
op { name: "FractionalAvgPoolGrad" hide: true }

View File

@@ -113,6 +113,14 @@ class BinaryOpsTest(XLATestCase):
np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
self._testBinary(
gen_nn_ops._selu_grad,
np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype),
expected=np.array(
[1.158099340847, 2.7161986816948, 4.67429802254,
4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype))
self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),

View File

@@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) {
});
}
TEST_F(OpTest, Selu) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SeluGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});

View File

@@ -229,6 +229,11 @@ class UnaryOpsTest(XLATestCase):
np.array([[-1, 0, 1]], dtype=dtype),
expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.selu,
np.array([[-1, 0, 1]], dtype=dtype),
expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),

View File

@@ -61,5 +61,49 @@ class EluGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Elu"), EluOp);
REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
class SeluOp : public XlaOpKernel {
public:
explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the scaled exponential linear unit of the input:
// scale * x if x > 0, scale_alpha * (exp(x) - 1) otherwise.
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
1.7580993408473768599402175208123);
const auto pred = b->Gt(ctx->Input(0), zero);
const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
b->Mul(scale_alpha, expm1)));
}
};
class SeluGradOp : public XlaOpKernel {
public:
explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Return the lhs (incoming gradient) times scale if the rhs (the Selu
// activation) > 0, otherwise return lhs * (rhs + scale_alpha).
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
1.7580993408473768599402175208123);
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
const auto lin_grad = b->Mul(grad, scale);
const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
const auto pred = b->Gt(activation, zero);
ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
}
};
REGISTER_XLA_OP(Name("Selu"), SeluOp);
REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp);
} // namespace
} // namespace tensorflow
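The two XLA kernels express the same select-based computation as the Eigen functors further down; a short NumPy check (illustrative, not part of the commit) reproduces the expected values used in the XLA unary and binary op tests above.

import numpy as np

scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123

# Selu: Select(x > 0, scale * x, scale_alpha * (exp(x) - 1))
x = np.array([-1.0, 0.0, 1.0])
print(np.where(x > 0, scale * x, scale_alpha * (np.exp(x) - 1)))
# approximately [-1.11133074  0.          1.05070099]

# SeluGrad: Select(activation > 0, grad * scale, grad * (activation + scale_alpha))
grad = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
activation = np.array([-0.6, -0.4, -0.2, 0.2, 0.4, 0.6])
print(np.where(activation > 0, grad * scale, grad * (activation + scale_alpha)))
# approximately [1.15809934 2.71619868 4.67429802 4.20280395 5.25350494 6.30420592]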

View File

@@ -56,9 +56,15 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
EluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
EluGradOp<CPUDevice, type>)
EluGradOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SeluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SeluGradOp<CPUDevice, type>)
// Elu only makes sense with float or double.
// Elu and Selu only make sense with float or double.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
@@ -103,7 +109,23 @@ namespace functor {
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor activations, \
typename TTypes<T>::Tensor backprops); \
extern template struct EluGrad<GPUDevice, T>;
extern template struct EluGrad<GPUDevice, T>; \
\
template <> \
void Selu<GPUDevice, T>::operator()( \
const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Selu<GPUDevice, T>; \
\
template <> \
void SeluGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor activations, \
typename TTypes<T>::Tensor backprops); \
extern template struct SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
@@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
EluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
EluGradOp<GPUDevice, type>)
EluGradOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
@@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
EluOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
EluGradOp<SYCLDevice, type>)
EluGradOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS

View File

@@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
output->flat<T>());
}
template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Selu<Device, T> functor;
functor(context->eigen_device<Device>(), input.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (outputs): outputs of the SeluOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::SeluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
} // namespace tensorflow
#undef EIGEN_USE_THREADS

View File

@@ -125,6 +125,46 @@ struct EluGrad {
}
};
// Functor used by SeluOp to do the computations.
template <typename Device, typename T>
struct Selu {
// Computes Selu activation.
//
// features: any shape.
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
const auto scale = static_cast<T>(1.0507009873554804934193349852946);
const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
const auto one = static_cast<T>(1);
const auto zero = static_cast<T>(0);
activations.device(d) =
(features < zero)
.select(scale_alpha * (features.exp() - features.constant(one)),
scale * features);
}
};
// Functor used by SeluGradOp to do the computations.
template <typename Device, typename T>
struct SeluGrad {
// Computes SeluGrad backprops.
//
// gradients: gradients backpropagated to the Selu op.
// activations: outputs of the Selu op.
// backprops: gradients to backpropagate to the Selu inputs.
void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
typename TTypes<T>::ConstTensor activations,
typename TTypes<T>::Tensor backprops) {
const auto scale = static_cast<T>(1.0507009873554804934193349852946);
const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
backprops.device(d) =
(activations < static_cast<T>(0)).select(
gradients * (activations + scale_alpha), gradients * scale);
}
};
} // namespace functor
} // namespace tensorflow
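The functors fold the paper's two constants into a single `scale_alpha` literal so the negative branch needs only one multiply; a quick check of that relation (the alpha value below comes from the SELU paper and is an assumption, not a value spelled out in this diff):

scale = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717  # from the SELU paper, not from this diff
scale_alpha = 1.7580993408473768599402175208123

assert abs(scale * alpha - scale_alpha) < 1e-10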

View File

@@ -35,7 +35,9 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::Relu6<GPUDevice, T>; \
template struct functor::Relu6Grad<GPUDevice, T>; \
template struct functor::Elu<GPUDevice, T>; \
template struct functor::EluGrad<GPUDevice, T>;
template struct functor::EluGrad<GPUDevice, T>; \
template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);

View File

@@ -1779,6 +1779,33 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
`gradients` otherwise.
)doc");
REGISTER_OP("Selu")
.Input("features: T")
.Output("activations: T")
.Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
if < 0, `scale * features` otherwise.
See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
)doc");
REGISTER_OP("SeluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
.Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the scaled exponential linear (Selu) operation.
gradients: The backpropagated gradients to the corresponding Selu operation.
outputs: The outputs of the corresponding Selu operation.
backprops: The gradients: `gradients * (outputs + scale * alpha)`
if outputs < 0, `scale * gradients` otherwise.
)doc");
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")

View File

@@ -412,7 +412,8 @@ TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {
TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
for (const char* op_name :
{"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) {
{"ReluGrad", "Relu6Grad", "EluGrad", "SeluGrad", "SoftplusGrad",
"SoftsignGrad"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "in0|in1");

View File

@@ -23383,6 +23383,60 @@ op {
summary: "Computes the eigen decomposition of one or more square self-adjoint matrices."
description: "Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in\n`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.\n\n```python\n# a is a tensor.\n# e is a tensor of eigenvalues.\n# v is a tensor of eigenvectors.\ne, v = self_adjoint_eig(a)\ne = self_adjoint_eig(a, compute_v=False)\n```"
}
op {
name: "Selu"
input_arg {
name: "features"
type_attr: "T"
}
output_arg {
name: "activations"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` if < 0, `scale * features` otherwise."
description: "See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)"
}
op {
name: "SeluGrad"
input_arg {
name: "gradients"
description: "The backpropagated gradients to the corresponding Selu operation."
type_attr: "T"
}
input_arg {
name: "outputs"
description: "The outputs of the corresponding Selu operation."
type_attr: "T"
}
output_arg {
name: "backprops"
description: "The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,\n`scale * gradients` otherwise."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes gradients for the scaled exponential linear (Selu) operation."
}
op {
name: "SerializeManySparse"
input_arg {

View File

@@ -8,7 +8,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
## Activation Functions
The activation ops provide different types of nonlinearities for use in neural
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`,
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, `selu`,
`softplus`, and `softsign`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, `crelu` and `relu_x`), and random regularization
(`dropout`).
@@ -20,6 +20,7 @@ shape as the input tensor.
* @{tf.nn.relu6}
* @{tf.nn.crelu}
* @{tf.nn.elu}
* @{tf.nn.selu}
* @{tf.nn.softplus}
* @{tf.nn.softsign}
* @{tf.nn.dropout}
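A minimal usage sketch of the new endpoint (TF 1.x style, matching this commit's era; the printed values are approximate and illustrative):

import tensorflow as tf

x = tf.constant([[-1.0, 0.5, 1.0]])
y = tf.nn.selu(x)              # the activation added by this commit
dy_dx = tf.gradients(y, x)[0]  # dispatches to the "Selu" gradient in nn_grad.py

with tf.Session() as sess:
    y_val, g_val = sess.run([y, dy_dx])
    # y_val ~ [[-1.11133074, 0.52535049, 1.05070099]]
    # g_val ~ [[ 0.6467686,  1.05070099, 1.05070099]]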

View File

@@ -16265,6 +16265,28 @@ func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyReso
return scope.AddOperation(opspec)
}
// Computes gradients for the scaled exponential linear (Selu) operation.
//
// Arguments:
// gradients: The backpropagated gradients to the corresponding Selu operation.
// outputs: The outputs of the corresponding Selu operation.
//
// Returns The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,
// `scale * gradients` otherwise.
func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "SeluGrad",
Input: []tf.Input{
gradients, outputs,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Converts each string in the input Tensor to its hash mod by a number of buckets.
//
// The hash function is deterministic on the content of the string within the
@@ -20541,6 +20563,24 @@ func Elu(scope *Scope, features tf.Output) (activations tf.Output) {
return op.Output(0)
}
// Computes scaled exponential linear: `1.758099 * (exp(features) - 1)` if < 0,
// `1.050701 * features` otherwise.
//
// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "Selu",
Input: []tf.Input{
features,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Computes square of x element-wise.
//
// I.e., \\(y = x * x = x^2\\).

View File

@@ -320,6 +320,97 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-6)
class SeluTest(test.TestCase):
def _npSelu(self, np_features):
scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123
return np.where(np_features < 0, scale_alpha * (np.exp(np_features) - 1),
scale * np_features)
def testNpSelu(self):
self.assertAllClose(
np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
[0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
self._npSelu(
np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
])))
def _testSelu(self, np_features, use_gpu=False):
np_selu = self._npSelu(np_features)
with self.test_session(use_gpu=use_gpu):
selu = nn_ops.selu(np_features)
tf_selu = selu.eval()
self.assertAllClose(np_selu, tf_selu)
self.assertShapeEqual(np_selu, selu)
def testNumbers(self):
for t in [np.float16, np.float32, np.float64]:
self._testSelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=False)
self._testSelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
def testGradientFloat32(self):
with self.test_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, name="x")
y = nn_ops.selu(x, name="selu")
x_init = np.asarray(x_val, dtype=np.float32, order="F")
err = gradient_checker.compute_gradient_error(
x, [2, 5], y, [2, 5], x_init_value=x_init)
print("selu (float32) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
with self.test_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
y = nn_ops.selu(x, name="selu")
x_init = np.asarray(x_val, dtype=np.float64, order="F")
err = gradient_checker.compute_gradient_error(
x, [2, 5], y, [2, 5], x_init_value=x_init)
print("selu (float64) gradient err = ", err)
self.assertLess(err, 1e-6)
def testGradGradFloat32(self):
with self.test_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
name="x")
y = nn_ops.selu(x, name="selu")
z = gradients_impl.gradients(y, x)
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32,
order="F")
err = gradient_checker.compute_gradient_error(
x, [2, 5], z[0], [2, 5], x_init_value=x_init)
print("selu (float32) gradient of gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
with self.test_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
dtype=dtypes.float64,
name="x")
y = nn_ops.selu(x, name="selu")
z = gradients_impl.gradients(y, x)
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float64,
order="F")
err = gradient_checker.compute_gradient_error(
x, [2, 5], z[0], [2, 5], x_init_value=x_init)
print("selu (float64) gradient of gradient err = ", err)
self.assertLess(err, 1e-6)
class CreluTest(test.TestCase):
def testCreluShape(self):

View File

@@ -290,6 +290,7 @@ MaxPool3DGradGrad
ReluGrad
Relu6Grad
EluGrad
SeluGrad
SoftplusGrad
SoftsignGrad
TopK

View File

@@ -22,6 +22,7 @@ See the @{$python/nn} guide.
@@relu6
@@crelu
@@elu
@@selu
@@softplus
@@softsign
@@dropout

View File

@@ -335,6 +335,16 @@ def _EluGradGrad(op, grad):
dtype=elu_x.dtype)))
@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
x = op.inputs[1]
scale_alpha = 1.7580993408473768599402175208123
return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
array_ops.where(
x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
return gen_nn_ops._relu6_grad(grad, op.inputs[0])
@@ -345,6 +355,11 @@ def _EluGrad(op, grad):
return gen_nn_ops._elu_grad(grad, op.outputs[0])
@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
return gen_nn_ops._selu_grad(grad, op.outputs[0])
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
return gen_nn_ops._softplus_grad(grad, op.inputs[0])
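Note that `_SeluGrad` feeds the op's output, not its input, to `gen_nn_ops._selu_grad`; on the negative branch the derivative scale * alpha * exp(x) equals selu(x) + scale_alpha, so the output alone is enough. A small NumPy check of that identity (illustrative, not part of the commit):

import numpy as np

scale_alpha = 1.7580993408473768599402175208123

x = np.linspace(-3.0, -1e-6, 7)
selu_x = scale_alpha * (np.exp(x) - 1.0)   # forward output on the negative branch
d_selu_dx = scale_alpha * np.exp(x)        # analytic derivative there
np.testing.assert_allclose(d_selu_dx, selu_x + scale_alpha)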

View File

@@ -256,6 +256,10 @@ tf_module {
name: "sampled_softmax_loss"
argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], "
}
member_method {
name: "selu"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "separable_conv2d"
argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "