From c2ce4f68c744e6d328746b144ff1fcf98ac99e6c Mon Sep 17 00:00:00 2001
From: Lakshay Garg
Date: Wed, 26 Jul 2017 10:18:01 +0530
Subject: [PATCH] Implemented selu activation #10612 (#10818)

* Implemented selu activation #10612

* fix the error in _SeluGradGrad

* update golden file for api change

* add XLA kernels for Selu and SeluGrad
---
 tensorflow/cc/gradients/nn_grad.cc            |  9 ++
 tensorflow/cc/gradients/nn_grad_test.cc       | 10 ++
 tensorflow/cc/ops/op_gen_overrides.pbtxt      |  1 +
 tensorflow/compiler/tests/binary_ops_test.py  |  8 ++
 tensorflow/compiler/tests/randomized_tests.cc | 17 ++++
 tensorflow/compiler/tests/unary_ops_test.py   |  5 +
 tensorflow/compiler/tf2xla/kernels/elu_op.cc  | 44 +++++++++
 tensorflow/core/kernels/relu_op.cc            | 60 +++++++++---
 tensorflow/core/kernels/relu_op.h             | 42 +++++++++
 tensorflow/core/kernels/relu_op_functor.h     | 40 ++++++++
 tensorflow/core/kernels/relu_op_gpu.cu.cc     |  4 +-
 tensorflow/core/ops/nn_ops.cc                 | 27 ++++++
 tensorflow/core/ops/nn_ops_test.cc            |  3 +-
 tensorflow/core/ops/ops.pbtxt                 | 54 +++++++++++
 tensorflow/docs_src/api_guides/python/nn.md   |  3 +-
 tensorflow/go/op/wrappers.go                  | 40 ++++++++
 .../python/kernel_tests/relu_op_test.py       | 91 +++++++++++++++++++
 tensorflow/python/ops/hidden_ops.txt          |  1 +
 tensorflow/python/ops/nn.py                   |  1 +
 tensorflow/python/ops/nn_grad.py              | 15 +++
 .../tools/api/golden/tensorflow.nn.pbtxt      |  4 +
 21 files changed, 465 insertions(+), 14 deletions(-)

diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 952b2015edf..f9d69ff8967 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -86,6 +86,15 @@ Status EluGradHelper(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("Elu", EluGradHelper);
 
+Status SeluGradHelper(const Scope& scope, const Operation& op,
+                      const std::vector<Output>& grad_inputs,
+                      std::vector<Output>* grad_outputs) {
+  auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Selu", SeluGradHelper);
+
 }  // anonymous namespace
 }  // namespace ops
 }  // namespace tensorflow
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index daa87546ec0..eab5b446261 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -103,5 +103,15 @@ TEST_F(NNGradTest, EluGrad) {
   RunTest(x, x_init_value, y, shape);
 }
 
+TEST_F(NNGradTest, SeluGrad) {
+  TensorShape shape({5, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto y = Selu(scope_, x);
+  Tensor x_init_value = test::AsTensor<float>(
+      {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+      {5, 2});
+  RunTest(x, x_init_value, y, shape);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt
index a1f79177f75..2252cbb2892 100644
--- a/tensorflow/cc/ops/op_gen_overrides.pbtxt
+++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt
@@ -177,6 +177,7 @@ op { name: "MaxPoolGradWithArgmax" hide: true }
 op { name: "ReluGrad" hide: true }
 op { name: "Relu6Grad" hide: true }
 op { name: "EluGrad" hide: true }
+op { name: "SeluGrad" hide: true }
 op { name: "SoftplusGrad" hide: true }
 op { name: "SoftsignGrad" hide: true }
 op { name: "FractionalAvgPoolGrad" hide: true }
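Note why `SeluGradHelper` passes `op.output(0)`, the forward activation, instead of the op input: for x < 0, d/dx selu(x) = scale_alpha * exp(x), which equals selu(x) + scale_alpha, so the backward pass can be computed from the output alone. A minimal NumPy sketch of that identity (illustrative, not part of the patch):

```python
import numpy as np

# Constants hard-coded by this patch; scale_alpha = scale * alpha.
scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123

def selu(x):
    return np.where(x < 0.0, scale_alpha * (np.exp(x) - 1.0), scale * x)

def selu_grad_from_output(grad, out):
    # For x < 0: selu'(x) = scale_alpha * exp(x) = selu(x) + scale_alpha,
    # so only the forward output is needed, never the original input.
    return np.where(out < 0.0, grad * (out + scale_alpha), grad * scale)

x = np.array([-0.9, -0.1, 0.1, 0.9])
g = np.ones_like(x)
print(selu_grad_from_output(g, selu(x)))
```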
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9eaede7f406..0bdbf53c39f 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -113,6 +113,14 @@ class BinaryOpsTest(XLATestCase):
           np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
           expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
 
+      self._testBinary(
+          gen_nn_ops._selu_grad,
+          np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
+          np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype),
+          expected=np.array(
+              [1.158099340847, 2.7161986816948, 4.67429802254,
+               4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype))
+
       self._testBinary(
           gen_nn_ops._relu_grad,
           np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index d3821ad02e5..825fd9de2eb 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) {
   });
 }
 
+TEST_F(OpTest, Selu) {
+  Repeatedly([this]() {
+    return ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
+  });
+}
+
+TEST_F(OpTest, SeluGrad) {
+  Repeatedly([this]() {
+    auto dims = RandomDims();
+    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad")
+                                             .RandomInput(DT_FLOAT, dims)
+                                             .RandomInput(DT_FLOAT, dims)
+                                             .Attr("T", DT_FLOAT));
+  });
+}
+
 TEST_F(OpTest, Equal) {
   Repeatedly([this]() {
     DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ce35eb91975..81ff18f3023 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -229,6 +229,11 @@ class UnaryOpsTest(XLATestCase):
           np.array([[-1, 0, 1]], dtype=dtype),
           expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
 
+      self._assertOpOutputMatchesExpected(
+          nn_ops.selu,
+          np.array([[-1, 0, 1]], dtype=dtype),
+          expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))
+
       self._assertOpOutputMatchesExpected(
           nn_ops.relu,
           np.array([[-1, 1]], dtype=dtype),
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 62a5e1bd421..2fd27c5ca7e 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -61,5 +61,49 @@ class EluGradOp : public XlaOpKernel {
 REGISTER_XLA_OP(Name("Elu"), EluOp);
 REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
 
+class SeluOp : public XlaOpKernel {
+ public:
+  explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Computes scale * x if x > 0, scale_alpha * (exp(x) - 1) otherwise.
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto scale = XlaHelpers::FloatLiteral(
+        b, input_type(0), 1.0507009873554804934193349852946);
+    const auto scale_alpha = XlaHelpers::FloatLiteral(
+        b, input_type(0), 1.7580993408473768599402175208123);
+    const auto pred = b->Gt(ctx->Input(0), zero);
+    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+    ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
+                                b->Mul(scale_alpha, expm1)));
+  }
+};
+
+class SeluGradOp : public XlaOpKernel {
+ public:
+  explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Return lhs (incoming gradient) * scale if the rhs (Selu activation) > 0,
+  // otherwise return lhs * (rhs + scale_alpha).
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto scale = XlaHelpers::FloatLiteral(
+        b, input_type(0), 1.0507009873554804934193349852946);
+    const auto scale_alpha = XlaHelpers::FloatLiteral(
+        b, input_type(0), 1.7580993408473768599402175208123);
+    const auto grad = ctx->Input(0);
+    const auto activation = ctx->Input(1);
+    const auto lin_grad = b->Mul(grad, scale);
+    const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
+    const auto pred = b->Gt(activation, zero);
+    ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
+  }
+};
+
+REGISTER_XLA_OP(Name("Selu"), SeluOp);
+REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp);
+
 }  // namespace
 }  // namespace tensorflow
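The expected values in the `binary_ops_test.py` hunk above can be reproduced from the same select logic the XLA `SeluGrad` kernel uses. A NumPy sketch (illustrative, not part of the patch):

```python
import numpy as np

scale = 1.0507009873554804934193349852946        # lambda in the paper
scale_alpha = 1.7580993408473768599402175208123  # lambda * alpha

def selu_grad(grad, activation):
    # Mirror of the XLA kernel: lin_grad where activation > 0, else exp_grad.
    lin_grad = grad * scale
    exp_grad = grad * (activation + scale_alpha)
    return np.where(activation > 0, lin_grad, exp_grad)

grads = np.array([1, 2, 3, 4, 5, 6], dtype=np.float32)
outs = np.array([-.6, -.4, -.2, .2, .4, .6], dtype=np.float32)
print(selu_grad(grads, outs))
# ~[1.1581, 2.7162, 4.6743, 4.2028, 5.2535, 6.3042], the test's expected array
```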
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index d8d30e87e22..afad288cc00 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -50,15 +50,21 @@ typedef Eigen::SyclDevice SYCLDevice;
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
 #undef REGISTER_RELU_KERNELS
 
-#define REGISTER_ELU_KERNELS(type)                                   \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
-      EluOp<CPUDevice, type>);                                       \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
-      EluGradOp<CPUDevice, type>)
+#define REGISTER_ELU_KERNELS(type)                                   \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
+      EluOp<CPUDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+      EluGradOp<CPUDevice, type>);                                   \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
+      SeluOp<CPUDevice, type>);                                      \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      SeluGradOp<CPUDevice, type>)
 
-// Elu only makes sense with float or double.
+// Elu and Selu only make sense with float or double.
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
 #undef REGISTER_ELU_KERNELS
 
@@ -103,7 +109,23 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
       typename TTypes<T>::ConstTensor activations,                    \
       typename TTypes<T>::Tensor backprops);                          \
-  extern template struct EluGrad<GPUDevice, T>;
+  extern template struct EluGrad<GPUDevice, T>;                       \
+                                                                      \
+  template <>                                                         \
+  void Selu<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d,                                             \
+      typename TTypes<T>::ConstTensor features,                       \
+      typename TTypes<T>::Tensor activations);                        \
+  extern template struct Selu<GPUDevice, T>;                          \
+                                                                      \
+  template <>                                                         \
+  void SeluGrad<GPUDevice, T>::operator()(                            \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
+      typename TTypes<T>::ConstTensor activations,                    \
+      typename TTypes<T>::Tensor backprops);                          \
+  extern template struct SeluGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor
@@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
       EluOp<GPUDevice, type>);                                        \
   REGISTER_KERNEL_BUILDER(                                            \
       Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
-      EluGradOp<GPUDevice, type>)
+      EluGradOp<GPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      SeluOp<GPUDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      SeluGradOp<GPUDevice, type>)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
@@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
       EluOp<SYCLDevice, type>);                                        \
   REGISTER_KERNEL_BUILDER(                                             \
       Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
-      EluGradOp<SYCLDevice, type>)
+      EluGradOp<SYCLDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"),      \
+      SeluOp<SYCLDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"),  \
+      SeluGradOp<SYCLDevice, type>)
 
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
 #undef REGISTER_SYCL_KERNELS
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 365c6201a54..e712b02bd78 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                   output->flat<T>());
 }
 
+template <typename Device, typename T>
+class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
+ public:
+  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Selu<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
+ public:
+  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (outputs): outputs of the SeluOp()
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              Tensor* output) {
+  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+  functor::SeluGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+
 }  // namespace tensorflow
 
 #undef EIGEN_USE_THREADS
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 633522920c8..9577b963c6b 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -125,6 +125,46 @@ struct EluGrad {
   }
 };
 
+// Functor used by SeluOp to do the computations.
+template <typename Device, typename T>
+struct Selu {
+  // Computes Selu activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+    const auto one = static_cast<T>(1);
+    const auto zero = static_cast<T>(0);
+    activations.device(d) =
+        (features < zero)
+            .select(scale_alpha * (features.exp() - features.constant(one)),
+                    scale * features);
+  }
+};
+
+// Functor used by SeluGradOp to do the computations.
+template <typename Device, typename T>
+struct SeluGrad {
+  // Computes SeluGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Selu op.
+  // activations: outputs of the Selu op.
+  // backprops: gradients to backpropagate to the Selu inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor activations,
+                  typename TTypes<T>::Tensor backprops) {
+    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+    backprops.device(d) =
+        (activations < static_cast<T>(0))
+            .select(gradients * (activations + scale_alpha),
+                    gradients * scale);
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 30c4a289f7f..ec09d8dfea5 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -35,7 +35,9 @@ typedef Eigen::GpuDevice GPUDevice;
   template struct functor::Relu6<GPUDevice, T>;     \
   template struct functor::Relu6Grad<GPUDevice, T>; \
   template struct functor::Elu<GPUDevice, T>;       \
-  template struct functor::EluGrad<GPUDevice, T>;
+  template struct functor::EluGrad<GPUDevice, T>;   \
+  template struct functor::Selu<GPUDevice, T>;      \
+  template struct functor::SeluGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 555b97f53b3..10187425214 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1779,6 +1779,33 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
 `gradients` otherwise.
 )doc");
 
+REGISTER_OP("Selu")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
+if < 0, `scale * features` otherwise.
+
+See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+)doc");
+
+REGISTER_OP("SeluGrad")
+    .Input("gradients: T")
+    .Input("outputs: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
+    .Doc(R"doc(
+Computes gradients for the scaled exponential linear (Selu) operation.
+
+gradients: The backpropagated gradients to the corresponding Selu operation.
+outputs: The outputs of the corresponding Selu operation.
+backprops: The gradients: `gradients * (outputs + scale * alpha)`
+if outputs < 0, `scale * gradients` otherwise.
+)doc");
+
 REGISTER_OP("Softplus")
     .Input("features: T")
     .Output("activations: T")
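With the op registered, the generated Python wrapper can be exercised end to end. A minimal usage sketch, assuming a TF 1.x build that includes this patch:

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.array([[-1.0, 0.0, 1.0]], dtype=np.float32))
y = tf.nn.selu(x)  # the op added by this patch

with tf.Session() as sess:
    # selu(-1) = scale_alpha * (e**-1 - 1), selu(1) = scale * 1
    print(sess.run(y))  # approx. [[-1.11133074, 0.0, 1.05070099]]
```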
+)doc"); + REGISTER_OP("Softplus") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index a60b1c37880..51e4f8bffe0 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -412,7 +412,8 @@ TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) { TEST(NNOpsTest, MergeBothInputs_ShapeFn) { for (const char* op_name : - {"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) { + {"ReluGrad", "Relu6Grad", "EluGrad", "SeluGrad", "SoftplusGrad", + "SoftsignGrad"}) { ShapeInferenceTestOp op(op_name); INFER_OK(op, "?;?", "in0|in1"); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 468434bd283..2839575ec72 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -23383,6 +23383,60 @@ op { summary: "Computes the eigen decomposition of one or more square self-adjoint matrices." description: "Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in\n`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.\n\n```python\n# a is a tensor.\n# e is a tensor of eigenvalues.\n# v is a tensor of eigenvectors.\ne, v = self_adjoint_eig(a)\ne = self_adjoint_eig(a, compute_v=False)\n```" } +op { + name: "Selu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` if < 0, `scale * features` otherwise." + description: "See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)" +} +op { + name: "SeluGrad" + input_arg { + name: "gradients" + description: "The backpropagated gradients to the corresponding Selu operation." + type_attr: "T" + } + input_arg { + name: "outputs" + description: "The outputs of the corresponding Selu operation." + type_attr: "T" + } + output_arg { + name: "backprops" + description: "The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,\n`scale * gradients` otherwise." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes gradients for the scaled exponential linear (Selu) operation." +} op { name: "SerializeManySparse" input_arg { diff --git a/tensorflow/docs_src/api_guides/python/nn.md b/tensorflow/docs_src/api_guides/python/nn.md index 4f188372a0f..75dbb04e7df 100644 --- a/tensorflow/docs_src/api_guides/python/nn.md +++ b/tensorflow/docs_src/api_guides/python/nn.md @@ -8,7 +8,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by ## Activation Functions The activation ops provide different types of nonlinearities for use in neural -networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, +networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, `selu`, `softplus`, and `softsign`), continuous but not everywhere differentiable functions (`relu`, `relu6`, `crelu` and `relu_x`), and random regularization (`dropout`). @@ -20,6 +20,7 @@ shape as the input tensor. 
 * @{tf.nn.relu6}
 * @{tf.nn.crelu}
 * @{tf.nn.elu}
+* @{tf.nn.selu}
 * @{tf.nn.softplus}
 * @{tf.nn.softsign}
 * @{tf.nn.dropout}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 095cbbe637b..43e09c498c6 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -16265,6 +16265,28 @@ func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyReso
 	return scope.AddOperation(opspec)
 }
 
+// Computes gradients for the scaled exponential linear (Selu) operation.
+//
+// Arguments:
+//	gradients: The backpropagated gradients to the corresponding Selu operation.
+//	outputs: The outputs of the corresponding Selu operation.
+//
+// Returns The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,
+// `scale * gradients` otherwise.
+func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SeluGrad",
+		Input: []tf.Input{
+			gradients, outputs,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Converts each string in the input Tensor to its hash mod by a number of buckets.
 //
 // The hash function is deterministic on the content of the string within the
@@ -20541,6 +20563,24 @@ func Elu(scope *Scope, features tf.Output) (activations tf.Output) {
 	return op.Output(0)
 }
 
+// Computes scaled exponential linear: `1.758099 * (exp(features) - 1)` if < 0,
+// `1.050701 * features` otherwise.
+//
+// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Selu",
+		Input: []tf.Input{
+			features,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Computes square of x element-wise.
 //
 // I.e., \\(y = x * x = x^2\\).
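The rounded constants in the Go doc comment (1.050701 and 1.758099) come from the fixed-point analysis in the paper. The patch only hard-codes scale (lambda) and scale * alpha; a quick check that the product matches, with alpha taken from the paper rather than from this patch:

```python
# Constants from "Self-Normalizing Neural Networks"
# (https://arxiv.org/abs/1706.02515); alpha itself never appears in the patch.
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

print(scale * alpha)                             # ~1.7580993408473766, the hard-coded scale_alpha
print(round(scale, 6), round(scale * alpha, 6))  # 1.050701 1.758099, the Go doc values
```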
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 63ac7438438..8cd1f52d800 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -320,6 +320,97 @@ class EluTest(test.TestCase):
     self.assertLess(err, 1e-6)
 
 
+class SeluTest(test.TestCase):
+
+  def _npSelu(self, np_features):
+    scale = 1.0507009873554804934193349852946
+    scale_alpha = 1.7580993408473768599402175208123
+    return np.where(np_features < 0, scale_alpha * (np.exp(np_features) - 1),
+                    scale * np_features)
+
+  def testNpSelu(self):
+    self.assertAllClose(
+        np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
+                  [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
+        self._npSelu(
+            np.array([[-0.9, 0.7, -0.5, 0.3, -0.1],
+                      [0.1, -0.3, 0.5, -0.7, 0.9]])))
+
+  def _testSelu(self, np_features, use_gpu=False):
+    np_selu = self._npSelu(np_features)
+    with self.test_session(use_gpu=use_gpu):
+      selu = nn_ops.selu(np_features)
+      tf_selu = selu.eval()
+    self.assertAllClose(np_selu, tf_selu)
+    self.assertShapeEqual(np_selu, selu)
+
+  def testNumbers(self):
+    for t in [np.float16, np.float32, np.float64]:
+      self._testSelu(
+          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+          use_gpu=False)
+      self._testSelu(
+          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+          use_gpu=True)
+
+  def testGradientFloat32(self):
+    with self.test_session():
+      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
+      x = constant_op.constant(x_val, name="x")
+      y = nn_ops.selu(x, name="selu")
+      x_init = np.asarray(x_val, dtype=np.float32, order="F")
+      err = gradient_checker.compute_gradient_error(
+          x, [2, 5], y, [2, 5], x_init_value=x_init)
+    print("selu (float32) gradient err = ", err)
+    self.assertLess(err, 1e-4)
+
+  def testGradientFloat64(self):
+    with self.test_session():
+      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
+      x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
+      y = nn_ops.selu(x, name="selu")
+      x_init = np.asarray(x_val, dtype=np.float64, order="F")
+      err = gradient_checker.compute_gradient_error(
+          x, [2, 5], y, [2, 5], x_init_value=x_init)
+    print("selu (float64) gradient err = ", err)
+    self.assertLess(err, 1e-6)
+
+  def testGradGradFloat32(self):
+    with self.test_session():
+      x = constant_op.constant(
+          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+          shape=[2, 5],
+          name="x")
+      y = nn_ops.selu(x, name="selu")
+      z = gradients_impl.gradients(y, x)
+      x_init = np.asarray(
+          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+          dtype=np.float32,
+          order="F")
+      err = gradient_checker.compute_gradient_error(
+          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+    print("selu (float32) gradient of gradient err = ", err)
+    self.assertLess(err, 1e-4)
+
+  def testGradGradFloat64(self):
+    with self.test_session():
+      x = constant_op.constant(
+          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+          shape=[2, 5],
+          dtype=dtypes.float64,
+          name="x")
+      y = nn_ops.selu(x, name="selu")
+      z = gradients_impl.gradients(y, x)
+      x_init = np.asarray(
+          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+          dtype=np.float64,
+          order="F")
+      err = gradient_checker.compute_gradient_error(
+          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+    print("selu (float64) gradient of gradient err = ", err)
+    self.assertLess(err, 1e-6)
+
+
 class CreluTest(test.TestCase):
 
   def testCreluShape(self):
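The gradient checker above compares analytic against numeric derivatives; the same comparison can be done by hand. A NumPy finite-difference sketch (illustrative, not part of the patch):

```python
import numpy as np

scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123

def selu(x):
    return np.where(x < 0, scale_alpha * (np.exp(x) - 1), scale * x)

def selu_grad(x):
    # Analytic derivative: scale_alpha * exp(x) for x < 0, scale otherwise.
    return np.where(x < 0, scale_alpha * np.exp(x), scale)

# Central differences away from the kink at 0 should agree to high precision.
x = np.array([-0.9, -0.3, 0.3, 0.9])
eps = 1e-5
numeric = (selu(x + eps) - selu(x - eps)) / (2 * eps)
print(np.max(np.abs(numeric - selu_grad(x))))  # tiny, like the checker's err
```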
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index b45d8987a3e..51973164326 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -290,6 +290,7 @@ MaxPool3DGradGrad
 ReluGrad
 Relu6Grad
 EluGrad
+SeluGrad
 SoftplusGrad
 SoftsignGrad
 TopK
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index d05cba2e930..60e9695dcb9 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -22,6 +22,7 @@ See the @{$python/nn} guide.
 @@relu6
 @@crelu
 @@elu
+@@selu
 @@softplus
 @@softsign
 @@dropout
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index f1453f9ef0d..50673ed4276 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -335,6 +335,16 @@ def _EluGradGrad(op, grad):
                              dtype=elu_x.dtype)))
 
 
+@ops.RegisterGradient("SeluGrad")
+def _SeluGradGrad(op, grad):
+  x = op.inputs[1]
+  scale_alpha = 1.7580993408473768599402175208123
+  return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
+          array_ops.where(
+              x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
+              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+
+
 @ops.RegisterGradient("Relu6")
 def _Relu6Grad(op, grad):
   return gen_nn_ops._relu6_grad(grad, op.inputs[0])
@@ -345,6 +355,11 @@ def _EluGrad(op, grad):
   return gen_nn_ops._elu_grad(grad, op.outputs[0])
 
 
+@ops.RegisterGradient("Selu")
+def _SeluGrad(op, grad):
+  return gen_nn_ops._selu_grad(grad, op.outputs[0])
+
+
 @ops.RegisterGradient("Softplus")
 def _SoftplusGrad(op, grad):
   return gen_nn_ops._softplus_grad(grad, op.inputs[0])
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 9f817beafd9..3beb95d25c1 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -256,6 +256,10 @@ tf_module {
     name: "sampled_softmax_loss"
     argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], "
   }
+  member_method {
+    name: "selu"
+    argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "separable_conv2d"
     argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
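The registered `_SeluGrad` above is what `tf.gradients` dispatches to for `tf.nn.selu`; note that it forwards `op.outputs[0]` (the activation) to the `SeluGrad` kernel, matching the output-based formulation used throughout this patch. A short TF 1.x sketch comparing the registered gradient against the analytic derivative, assuming a build that includes this patch:

```python
import numpy as np
import tensorflow as tf

scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123

x_val = np.array([-0.9, -0.1, 0.1, 0.9], dtype=np.float32)
x = tf.constant(x_val)
y = tf.nn.selu(x)
dy_dx = tf.gradients(y, x)[0]  # dispatches to _SeluGrad -> SeluGrad kernel

with tf.Session() as sess:
    analytic = np.where(x_val < 0, scale_alpha * np.exp(x_val), scale)
    print(sess.run(dy_dx))  # should match `analytic` to float32 precision
    print(analytic)
```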