diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 6ac333b141f..7666ba23eae 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -410,20 +410,17 @@ def _ReluGrad(op, grad):
 @ops.RegisterGradient("EluGrad")
 def _EluGradGrad(op, grad):
   elu_x = op.inputs[1]
-  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
+  return (gen_nn_ops.elu_grad(grad, elu_x),
           array_ops.where(
-              elu_x < 0, grad * op.inputs[0],
-              array_ops.zeros(shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))
+              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))
 
 
 @ops.RegisterGradient("SeluGrad")
 def _SeluGradGrad(op, grad):
-  x = op.inputs[1]
-  scale_alpha = 1.7580993408473768599402175208123
-  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
+  selu_x = op.inputs[1]
+  return (gen_nn_ops.selu_grad(grad, selu_x),
           array_ops.where(
-              x < 0., gen_nn_ops.elu_grad(grad, op.outputs[0] + scale_alpha),
-              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))
 
 
 @ops.RegisterGradient("Relu6")
diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py
index 783656a8693..9da56cb7200 100644
--- a/tensorflow/python/ops/nn_grad_test.py
+++ b/tensorflow/python/ops/nn_grad_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -157,5 +158,79 @@ class DepthwiseConv2dTest(test.TestCase):
     self.run_test(x, grad_wrt_filter)
 
 
+class EluGradOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testEluGradGradWRTgrad_ys(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+
+    elu = gen_nn_ops.elu(inputs)
+    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          dummy,
+          dummy.shape,
+          elu_grad,
+          elu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+  @test_util.run_deprecated_v1
+  def testEluGradGradWRTinputs(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+
+    elu = gen_nn_ops.elu(inputs)
+    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          inputs,
+          inputs.shape,
+          elu_grad,
+          elu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+
+class SeluGradOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testSeluGradGradWRTgrad_ys(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+
+    selu = gen_nn_ops.selu(inputs)
+    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          dummy,
+          dummy.shape,
+          selu_grad,
+          selu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+  @test_util.run_deprecated_v1
+  def testSeluGradGradWRTinputs(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+
+    selu = gen_nn_ops.selu(inputs)
+    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          inputs,
+          inputs.shape,
+          selu_grad,
+          selu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+
 if __name__ == "__main__":
   test.main()
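Note on the change: `_SeluGradGrad` previously reused `gen_nn_ops.elu_grad` on `op.outputs[0]` together with a hard-coded `scale_alpha` constant; the patch instead routes the second derivative through `gen_nn_ops.selu_grad` applied to the saved forward activation `op.inputs[1]`, mirroring the corrected `_EluGradGrad`. The snippet below is a minimal standalone sanity check, not part of the patch, assuming TensorFlow 2.x eager execution (the tests above use the TF1-style `gradient_checker` instead); it compares the second derivative produced by the registered gradients against a central finite-difference estimate.

import tensorflow as tf

x = tf.constant([-1.5, -0.3, 0.7, 2.0], dtype=tf.float64)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = tf.nn.selu(x)
  dy_dx = inner.gradient(y, x)      # first derivative: uses the SeluGrad op
d2y_dx2 = outer.gradient(dy_dx, x)  # second derivative: exercises the
                                    # _SeluGradGrad registration in this patch

# Central second difference of selu as a numerical reference.
eps = 1e-3
fd = (tf.nn.selu(x + eps) - 2.0 * tf.nn.selu(x) + tf.nn.selu(x - eps)) / eps**2
print(tf.reduce_max(tf.abs(d2y_dx2 - fd)))  # expect ~1e-6 or below

For x >= 0, selu is linear, so the second derivative should be exactly zero there; for x < 0 it is scale * alpha * exp(x), which is what the `selu_grad`-based formulation recovers.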