Correcting EluGradGrad: the gradient of EluGrad w.r.t. the gradients pushed to it was not correctly defined.

Correcting SeluGradGrad: the gradient of SeluGrad w.r.t. the gradients pushed to it was not correctly defined.

PiperOrigin-RevId: 245657610
A. Unique TensorFlower 2019-04-28 13:30:56 -07:00 committed by TensorFlower Gardener
parent b1190ccd07
commit 84c5a4551e
2 changed files with 80 additions and 8 deletions
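For context (this note is not part of the commit): EluGrad takes (gradients, outputs) and computes gradients * ELU'(x), recovering ELU'(x) from the saved forward output. Because that computation is linear in the incoming gradients, its gradient w.r.t. those gradients is just ELU'(x) applied to the new upstream gradient, i.e. elu_grad(grad, elu_x) where elu_x = op.inputs[1] is the forward ELU output. The old code passed op.outputs[0], the first-order gradient itself, as the "outputs" argument, which gives a wrong result wherever x < 0. A minimal NumPy sketch of the elu_grad semantics (illustrative names, not TensorFlow API):

import numpy as np

def elu(x, alpha=1.0):
  # ELU(x) = x for x > 0, alpha * (exp(x) - 1) otherwise.
  return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def elu_grad(g, y, alpha=1.0):
  # Mirrors the semantics of gen_nn_ops.elu_grad: ELU'(x) is recovered
  # from the forward output y, since ELU'(x) = ELU(x) + alpha for x <= 0.
  return np.where(y > 0, g, g * (y + alpha))

x = np.array([-2.0, -0.5, 0.5, 2.0])
g = np.array([3.0, 1.0, -1.0, -2.0])
y = elu(x)

# EluGrad is linear in g, so its derivative w.r.t. g applied to a new
# upstream gradient is elu_grad(grad, y) -- it needs the forward output y,
# not the first-order gradient that the old code passed in.
print(elu_grad(g, y))                       # first-order backward pass
print(elu_grad(np.ones_like(g), y) * g)     # identical, by linearity in g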


@@ -410,20 +410,17 @@ def _ReluGrad(op, grad):
 @ops.RegisterGradient("EluGrad")
 def _EluGradGrad(op, grad):
   elu_x = op.inputs[1]
-  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
+  return (gen_nn_ops.elu_grad(grad, elu_x),
           array_ops.where(
-              elu_x < 0, grad * op.inputs[0],
-              array_ops.zeros(shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))
+              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))
 
 
 @ops.RegisterGradient("SeluGrad")
 def _SeluGradGrad(op, grad):
-  x = op.inputs[1]
-  scale_alpha = 1.7580993408473768599402175208123
-  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
+  selu_x = op.inputs[1]
+  return (gen_nn_ops.selu_grad(grad, selu_x),
           array_ops.where(
-              x < 0., gen_nn_ops.elu_grad(grad, op.outputs[0] + scale_alpha),
-              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))
 
 
 @ops.RegisterGradient("Relu6")


@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -157,5 +158,79 @@ class DepthwiseConv2dTest(test.TestCase):
     self.run_test(x, grad_wrt_filter)
 
 
+class EluGradOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testEluGradGradWRTgrad_ys(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+    elu = gen_nn_ops.elu(inputs)
+    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          dummy,
+          dummy.shape,
+          elu_grad,
+          elu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+  @test_util.run_deprecated_v1
+  def testEluGradGradWRTinputs(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+    elu = gen_nn_ops.elu(inputs)
+    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          inputs,
+          inputs.shape,
+          elu_grad,
+          elu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+
+class SeluGradOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testSeluGradGradWRTgrad_ys(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+    selu = gen_nn_ops.selu(inputs)
+    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          dummy,
+          dummy.shape,
+          selu_grad,
+          selu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+  @test_util.run_deprecated_v1
+  def testSeluGradGradWRTinputs(self):
+    inputs = constant_op.constant(
+        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
+    dummy = constant_op.constant(
+        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)
+    selu = gen_nn_ops.selu(inputs)
+    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
+    with self.cached_session():
+      error = gradient_checker.compute_gradient_error(
+          inputs,
+          inputs.shape,
+          selu_grad,
+          selu_grad.shape)
+      self.assertLess(error, 1e-4)
+
+
 if __name__ == "__main__":
   test.main()
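gradient_checker.compute_gradient_error used above returns the maximum elementwise difference between the theoretical Jacobian and a numerically estimated one. The same idea in compact eager form for the SELU case (a hedged sketch assuming TF 2.x; selu_grad_wrt_x is an illustrative helper, not TensorFlow API):

import tensorflow as tf

def selu_grad_wrt_x(x, dummy):
  # First-order gradient of SELU with upstream gradient `dummy`.
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.selu(x)
  return tape.gradient(y, x, output_gradients=dummy)

x = tf.constant([-2., -1., 1., 3.])
dummy = tf.constant([3., 1., -1., -2.])
eps = 1e-3

with tf.GradientTape() as tape:
  tape.watch(x)
  g = selu_grad_wrt_x(x, dummy)
analytic = tape.gradient(g, x)  # exercises _SeluGradGrad

# Central finite difference of the first-order gradient (elementwise,
# valid because SELU acts elementwise).
numeric = (selu_grad_wrt_x(x + eps, dummy) -
           selu_grad_wrt_x(x - eps, dummy)) / (2. * eps)
assert tf.reduce_max(tf.abs(analytic - numeric)) < 1e-3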