* Implemented selu activation #10612
* fix the error in _SeluGradGrad
* update golden file for api change
* add XLA kernels for Selu and SeluGrad
parent 80d57aeadd
commit c2ce4f68c7
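For reference, the activation these kernels implement can be sketched in NumPy as follows (a minimal sketch, not part of the diff; the `SCALE` and `SCALE_ALPHA` literals match the constants used throughout the commit, with `scale_alpha = scale * alpha`):

import numpy as np

SCALE = 1.0507009873554804934193349852946
SCALE_ALPHA = 1.7580993408473768599402175208123  # scale * alpha

def selu(x):
  # scale * x for x > 0, scale * alpha * (exp(x) - 1) otherwise.
  return np.where(x > 0.0, SCALE * x, SCALE_ALPHA * (np.exp(x) - 1.0))

def selu_grad(gradients, outputs):
  # Mirrors the SeluGrad op below: the second argument is the *output* of
  # Selu, so the negative branch is gradients * (outputs + scale * alpha).
  return np.where(outputs > 0.0, SCALE * gradients,
                  gradients * (outputs + SCALE_ALPHA))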
@@ -86,6 +86,15 @@ Status EluGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);

Status SeluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
@@ -103,5 +103,15 @@ TEST_F(NNGradTest, EluGrad) {
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NNGradTest, SeluGrad) {
  TensorShape shape({5, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Selu(scope_, x);
  Tensor x_init_value = test::AsTensor<float>(
      {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
      {5, 2});
  RunTest(x, x_init_value, y, shape);
}

}  // namespace
}  // namespace tensorflow
@@ -177,6 +177,7 @@ op { name: "MaxPoolGradWithArgmax" hide: true }
op { name: "ReluGrad" hide: true }
op { name: "Relu6Grad" hide: true }
op { name: "EluGrad" hide: true }
op { name: "SeluGrad" hide: true }
op { name: "SoftplusGrad" hide: true }
op { name: "SoftsignGrad" hide: true }
op { name: "FractionalAvgPoolGrad" hide: true }
@@ -113,6 +113,14 @@ class BinaryOpsTest(XLATestCase):
          np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
          expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))

      self._testBinary(
          gen_nn_ops._selu_grad,
          np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
          np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype),
          expected=np.array(
              [1.158099340847, 2.7161986816948, 4.67429802254,
               4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype))

      self._testBinary(
          gen_nn_ops._relu_grad,
          np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
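The expected values in the `_selu_grad` case above follow directly from the SeluGrad rule (`gradients * (outputs + scale * alpha)` where the second operand is negative, `scale * gradients` otherwise). A small NumPy check, not part of the commit:

import numpy as np

scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123  # scale * alpha

grads = np.array([1., 2., 3., 4., 5., 6.])
outputs = np.array([-.6, -.4, -.2, .2, .4, .6])
backprops = np.where(outputs < 0, grads * (outputs + scale_alpha), scale * grads)
print(backprops)  # ~[1.1581, 2.7162, 4.6743, 4.2028, 5.2535, 6.3042]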
@@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) {
  });
}

TEST_F(OpTest, Selu) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SeluGrad) {
  Repeatedly([this]() {
    auto dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad")
                                             .RandomInput(DT_FLOAT, dims)
                                             .RandomInput(DT_FLOAT, dims)
                                             .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Equal) {
  Repeatedly([this]() {
    DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
@@ -229,6 +229,11 @@ class UnaryOpsTest(XLATestCase):
          np.array([[-1, 0, 1]], dtype=dtype),
          expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))

      self._assertOpOutputMatchesExpected(
          nn_ops.selu,
          np.array([[-1, 0, 1]], dtype=dtype),
          expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))

      self._assertOpOutputMatchesExpected(
          nn_ops.relu,
          np.array([[-1, 1]], dtype=dtype),
@@ -61,5 +61,49 @@ class EluGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Elu"), EluOp);
REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);

class SeluOp : public XlaOpKernel {
 public:
  explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  // Computes the scaled exponential linear unit of the input:
  // scale * x if x > 0, scale * alpha * (exp(x) - 1) otherwise.
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();
    const auto zero = XlaHelpers::Zero(b, input_type(0));
    const auto one = XlaHelpers::One(b, input_type(0));
    const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
        1.0507009873554804934193349852946);
    const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
        1.7580993408473768599402175208123);
    const auto pred = b->Gt(ctx->Input(0), zero);
    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
    ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
                                b->Mul(scale_alpha, expm1)));
  }
};

class SeluGradOp : public XlaOpKernel {
 public:
  explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  // Return lhs * scale if the rhs (the Selu activation) > 0, otherwise
  // return lhs * (rhs + scale * alpha).
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();
    const auto zero = XlaHelpers::Zero(b, input_type(0));
    const auto one = XlaHelpers::One(b, input_type(0));
    const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
        1.0507009873554804934193349852946);
    const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
        1.7580993408473768599402175208123);
    const auto grad = ctx->Input(0);
    const auto activation = ctx->Input(1);
    const auto lin_grad = b->Mul(grad, scale);
    const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
    const auto pred = b->Gt(activation, zero);
    ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
  }
};

REGISTER_XLA_OP(Name("Selu"), SeluOp);
REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp);

}  // namespace
}  // namespace tensorflow
@@ -56,9 +56,15 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
      EluOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      EluGradOp<CPUDevice, type>)
      EluGradOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SeluOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SeluGradOp<CPUDevice, type>)

// Elu only makes sense with float or double.
// Elu and Selu only make sense with float or double.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
@@ -103,7 +109,23 @@ namespace functor {
      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
      typename TTypes<T>::ConstTensor activations, \
      typename TTypes<T>::Tensor backprops); \
  extern template struct EluGrad<GPUDevice, T>;
  extern template struct EluGrad<GPUDevice, T>; \
  \
  template <> \
  void Selu<GPUDevice, T>::operator()( \
      const GPUDevice& d, \
      typename TTypes<T>::ConstTensor features, \
      typename TTypes<T>::Tensor activations); \
  extern template struct Selu<GPUDevice, T>; \
  \
  template <> \
  void SeluGrad<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
      typename TTypes<T>::ConstTensor activations, \
      typename TTypes<T>::Tensor backprops); \
  extern template struct SeluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
}  // namespace functor
@@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
      EluOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      EluGradOp<GPUDevice, type>)
      EluGradOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SeluOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SeluGradOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
@@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
      EluOp<SYCLDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      EluGradOp<SYCLDevice, type>)
      EluGradOp<SYCLDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SeluOp<SYCLDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SeluGradOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
@@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
          output->flat<T>());
}

template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Selu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};

template <typename Device, typename T>
class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the SeluOp()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::SeluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}

}  // namespace tensorflow

#undef EIGEN_USE_THREADS
@@ -125,6 +125,46 @@ struct EluGrad {
  }
};

// Functor used by SeluOp to do the computations.
template <typename Device, typename T>
struct Selu {
  // Computes Selu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    // features.constant(?)
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
    const auto one = static_cast<T>(1);
    const auto zero = static_cast<T>(0);
    activations.device(d) =
        (features < zero)
            .select(scale_alpha * (features.exp() - features.constant(one)),
                    scale * features);
  }
};

// Functor used by SeluGradOp to do the computations.
template <typename Device, typename T>
struct SeluGrad {
  // Computes SeluGrad backprops.
  //
  // gradients: gradients backpropagated to the Selu op.
  // activations: outputs of the Selu op.
  // backprops: gradients to backpropagate to the Selu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor activations,
                  typename TTypes<T>::Tensor backprops) {
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
    backprops.device(d) =
        (activations < static_cast<T>(0)).select(
            gradients * (activations + scale_alpha), gradients * scale);
  }
};

}  // namespace functor
}  // namespace tensorflow
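Note that `scale_alpha` in these functors is the product of the two published SELU constants (scale ≈ 1.0507, alpha ≈ 1.6733 from the paper linked in the op docs below; the alpha value itself does not appear in the diff). A quick NumPy check of that identity and of the negative branch, not part of the commit:

import numpy as np

scale = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717  # assumed paper value
scale_alpha = 1.7580993408473768599402175208123
assert abs(scale * alpha - scale_alpha) < 1e-12

x = np.array([-1.0, -0.1, 0.5])
selu = np.where(x < 0, scale_alpha * (np.exp(x) - 1), scale * x)
print(selu)  # ~[-1.1113, -0.1673, 0.5254]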
@@ -35,7 +35,9 @@ typedef Eigen::GpuDevice GPUDevice;
  template struct functor::Relu6<GPUDevice, T>; \
  template struct functor::Relu6Grad<GPUDevice, T>; \
  template struct functor::Elu<GPUDevice, T>; \
  template struct functor::EluGrad<GPUDevice, T>;
  template struct functor::EluGrad<GPUDevice, T>; \
  template struct functor::Selu<GPUDevice, T>; \
  template struct functor::SeluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
@@ -1779,6 +1779,33 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
`gradients` otherwise.
)doc");

REGISTER_OP("Selu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
if < 0, `scale * features` otherwise.

See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
)doc");

REGISTER_OP("SeluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
    .Doc(R"doc(
Computes gradients for the scaled exponential linear (Selu) operation.

gradients: The backpropagated gradients to the corresponding Selu operation.
outputs: The outputs of the corresponding Selu operation.
backprops: The gradients: `gradients * (outputs + scale * alpha)`
if outputs < 0, `scale * gradients` otherwise.
)doc");

REGISTER_OP("Softplus")
    .Input("features: T")
    .Output("activations: T")
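The documented backprop rule can be sanity-checked against the forward definition with central differences (a NumPy sketch, not part of the commit): for features < 0 the derivative is `scale * alpha * exp(x)`, which equals `outputs + scale * alpha`, and for features > 0 it is simply `scale`.

import numpy as np

scale = 1.0507009873554804934193349852946
scale_alpha = 1.7580993408473768599402175208123

def selu(x):
  return np.where(x > 0, scale * x, scale_alpha * (np.exp(x) - 1))

x = np.array([-2.0, -0.5, 0.3, 1.5])
eps = 1e-6
numeric = (selu(x + eps) - selu(x - eps)) / (2 * eps)
outputs = selu(x)
analytic = np.where(outputs < 0, outputs + scale_alpha, scale)  # gradients == 1
print(np.max(np.abs(numeric - analytic)))  # tiny (~1e-9): the rule matches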
@@ -412,7 +412,8 @@ TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {

TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
  for (const char* op_name :
       {"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) {
       {"ReluGrad", "Relu6Grad", "EluGrad", "SeluGrad", "SoftplusGrad",
        "SoftsignGrad"}) {
    ShapeInferenceTestOp op(op_name);

    INFER_OK(op, "?;?", "in0|in1");
@@ -23383,6 +23383,60 @@ op {
  summary: "Computes the eigen decomposition of one or more square self-adjoint matrices."
  description: "Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in\n`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.\n\n```python\n# a is a tensor.\n# e is a tensor of eigenvalues.\n# v is a tensor of eigenvectors.\ne, v = self_adjoint_eig(a)\ne = self_adjoint_eig(a, compute_v=False)\n```"
}
op {
  name: "Selu"
  input_arg {
    name: "features"
    type_attr: "T"
  }
  output_arg {
    name: "activations"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
      }
    }
  }
  summary: "Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` if < 0, `scale * features` otherwise."
  description: "See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)"
}
op {
  name: "SeluGrad"
  input_arg {
    name: "gradients"
    description: "The backpropagated gradients to the corresponding Selu operation."
    type_attr: "T"
  }
  input_arg {
    name: "outputs"
    description: "The outputs of the corresponding Selu operation."
    type_attr: "T"
  }
  output_arg {
    name: "backprops"
    description: "The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,\n`scale * gradients` otherwise."
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
      }
    }
  }
  summary: "Computes gradients for the scaled exponential linear (Selu) operation."
}
op {
  name: "SerializeManySparse"
  input_arg {
@@ -8,7 +8,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
## Activation Functions

The activation ops provide different types of nonlinearities for use in neural
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`,
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, `selu`,
`softplus`, and `softsign`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, `crelu` and `relu_x`), and random regularization
(`dropout`).
@@ -20,6 +20,7 @@ shape as the input tensor.
* @{tf.nn.relu6}
* @{tf.nn.crelu}
* @{tf.nn.elu}
* @{tf.nn.selu}
* @{tf.nn.softplus}
* @{tf.nn.softsign}
* @{tf.nn.dropout}
@@ -16265,6 +16265,28 @@ func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyReso
	return scope.AddOperation(opspec)
}

// Computes gradients for the scaled exponential linear (Selu) operation.
//
// Arguments:
// gradients: The backpropagated gradients to the corresponding Selu operation.
// outputs: The outputs of the corresponding Selu operation.
//
// Returns The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,
// `scale * gradients` otherwise.
func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "SeluGrad",
		Input: []tf.Input{
			gradients, outputs,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Converts each string in the input Tensor to its hash mod by a number of buckets.
//
// The hash function is deterministic on the content of the string within the
@@ -20541,6 +20563,24 @@ func Elu(scope *Scope, features tf.Output) (activations tf.Output) {
	return op.Output(0)
}

// Computes scaled exponential linear: `1.758099 * (exp(features) - 1)` if < 0,
// `1.050701 * features` otherwise.
//
// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "Selu",
		Input: []tf.Input{
			features,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Computes square of x element-wise.
//
// I.e., \\(y = x * x = x^2\\).
@@ -320,6 +320,97 @@ class EluTest(test.TestCase):
      self.assertLess(err, 1e-6)


class SeluTest(test.TestCase):

  def _npSelu(self, np_features):
    scale = 1.0507009873554804934193349852946
    scale_alpha = 1.7580993408473768599402175208123
    return np.where(np_features < 0, scale_alpha * (np.exp(np_features) - 1),
                    scale * np_features)

  def testNpSelu(self):
    self.assertAllClose(
        np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
                  [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
        self._npSelu(
            np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
                     ])))

  def _testSelu(self, np_features, use_gpu=False):
    np_selu = self._npSelu(np_features)
    with self.test_session(use_gpu=use_gpu):
      selu = nn_ops.selu(np_features)
      tf_selu = selu.eval()
    self.assertAllClose(np_selu, tf_selu)
    self.assertShapeEqual(np_selu, selu)

  def testNumbers(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testSelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      self._testSelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=True)

  def testGradientFloat32(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, name="x")
      y = nn_ops.selu(x, name="selu")
      x_init = np.asarray(x_val, dtype=np.float32, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("selu (float32) gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradientFloat64(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
      y = nn_ops.selu(x, name="selu")
      x_init = np.asarray(x_val, dtype=np.float64, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("selu (float64) gradient err = ", err)
    self.assertLess(err, 1e-6)

  def testGradGradFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.selu(x, name="selu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("selu (float32) gradient of gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradGradFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.selu(x, name="selu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("selu (float64) gradient of gradient err = ", err)
    self.assertLess(err, 1e-6)


class CreluTest(test.TestCase):

  def testCreluShape(self):
@@ -290,6 +290,7 @@ MaxPool3DGradGrad
ReluGrad
Relu6Grad
EluGrad
SeluGrad
SoftplusGrad
SoftsignGrad
TopK
@@ -22,6 +22,7 @@ See the @{$python/nn} guide.
@@relu6
@@crelu
@@elu
@@selu
@@softplus
@@softsign
@@dropout
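With `@@selu` exported, the op is callable as `tf.nn.selu`. A minimal usage sketch in TF 1.x graph style (not part of the commit; the tensor shapes and variable names are illustrative only):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 128])
w = tf.get_variable("w", shape=[128, 64])
hidden = tf.nn.selu(tf.matmul(x, w))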
@@ -335,6 +335,16 @@ def _EluGradGrad(op, grad):
              dtype=elu_x.dtype)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  x = op.inputs[1]
  scale_alpha = 1.7580993408473768599402175208123
  return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
          array_ops.where(
              x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops._relu6_grad(grad, op.inputs[0])
@@ -345,6 +355,11 @@ def _EluGrad(op, grad):
  return gen_nn_ops._elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops._selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return gen_nn_ops._softplus_grad(grad, op.inputs[0])
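A quick TF 1.x sketch (not part of the commit) that exercises the newly registered `Selu` gradient; note that, like the kernel, `_SeluGrad` passes the op's *output* to `_selu_grad`:

import tensorflow as tf

x = tf.constant([-1.0, 0.5, 2.0])
y = tf.nn.selu(x)
dy_dx, = tf.gradients(y, x)
with tf.Session() as sess:
  print(sess.run(dy_dx))  # ~[0.6468, 1.0507, 1.0507]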
@@ -256,6 +256,10 @@ tf_module {
    name: "sampled_softmax_loss"
    argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], "
  }
  member_method {
    name: "selu"
    argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "separable_conv2d"
    argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }