Correcting EluGradGrad: the gradient of EluGrad w.r.t. the gradients pushed into it was not correctly defined.
Correcting SeluGradGrad: the gradient of SeluGrad w.r.t. the gradients pushed into it was not correctly defined. PiperOrigin-RevId: 245657610
This commit is contained in:
parent
b1190ccd07
commit
84c5a4551e
@ -410,20 +410,17 @@ def _ReluGrad(op, grad):
|
||||
@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  """Second-order gradient for Elu: the gradient of the EluGrad op.

  Args:
    op: The EluGrad op being differentiated. `op.inputs[0]` is the upstream
      gradient (the `grad_ys` originally pushed into EluGrad) and
      `op.inputs[1]` is the Elu activation value.
    grad: Incoming gradient w.r.t. EluGrad's output.

  Returns:
    A 2-tuple of gradients w.r.t. EluGrad's two inputs:
    (gradient w.r.t. the pushed-in gradient, gradient w.r.t. the activation).
  """
  elu_x = op.inputs[1]
  # EluGrad is linear in its first input (the pushed-in gradient), so the
  # gradient w.r.t. it is EluGrad applied to `grad` with the SAME activation
  # input `elu_x` — not `op.outputs[0]`, which was the previous (buggy) code.
  # For the activation input: elu'(x) = elu(x) + 1 on the negative side, so
  # d(EluGrad)/d(elu_x) = op.inputs[0] there and 0 on the non-negative side.
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))
|
||||
|
||||
@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  """Second-order gradient for Selu: the gradient of the SeluGrad op.

  Args:
    op: The SeluGrad op being differentiated. `op.inputs[0]` is the upstream
      gradient (the `grad_ys` originally pushed into SeluGrad) and
      `op.inputs[1]` is the Selu activation value.
    grad: Incoming gradient w.r.t. SeluGrad's output.

  Returns:
    A 2-tuple of gradients w.r.t. SeluGrad's two inputs:
    (gradient w.r.t. the pushed-in gradient, gradient w.r.t. the activation).
  """
  selu_x = op.inputs[1]
  # SeluGrad is linear in its first input, so the gradient w.r.t. it is
  # SeluGrad itself applied to `grad` with the same activation input. This
  # replaces the previous hand-rolled elu_grad/scale_alpha expression, which
  # mis-used op.outputs[0] and a hard-coded constant.
  # For the activation input: on the negative side selu'(x) differs from
  # selu(x) by a constant, so the second derivative contribution is just the
  # pushed-in gradient (op.inputs[0]) times `grad`; zero elsewhere.
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0],
              array_ops.zeros_like(selu_x)))
@ops.RegisterGradient("Relu6")
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
@ -157,5 +158,79 @@ class DepthwiseConv2dTest(test.TestCase):
|
||||
self.run_test(x, grad_wrt_filter)
|
||||
|
||||
|
||||
class EluGradOpTest(test.TestCase):
  """Numerically verifies the second-order gradients of Elu.

  Both tests build `elu_grad = d(elu)/d(inputs)` with a fixed `grad_ys`
  tensor (`dummy`) and then finite-difference-check the gradient of that
  expression w.r.t. one of its two inputs.
  """

  @test_util.run_deprecated_v1
  def testEluGradGradWRTgrad_ys(self):
    # Checks the gradient of EluGrad w.r.t. the gradients pushed into it
    # (the `grad_ys` argument), which the fixed _EluGradGrad now defines.
    inputs = constant_op.constant(
        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
    dummy = constant_op.constant(
        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)

    elu = gen_nn_ops.elu(inputs)
    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
    with self.cached_session():
      # Finite-difference check of d(elu_grad)/d(dummy).
      error = gradient_checker.compute_gradient_error(
          dummy,
          dummy.shape,
          elu_grad,
          elu_grad.shape)
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testEluGradGradWRTinputs(self):
    # Checks the gradient of EluGrad w.r.t. the original Elu inputs.
    inputs = constant_op.constant(
        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
    dummy = constant_op.constant(
        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)

    elu = gen_nn_ops.elu(inputs)
    elu_grad = gradients_impl.gradients(elu, inputs, grad_ys=dummy)[0]
    with self.cached_session():
      # Finite-difference check of d(elu_grad)/d(inputs).
      error = gradient_checker.compute_gradient_error(
          inputs,
          inputs.shape,
          elu_grad,
          elu_grad.shape)
      self.assertLess(error, 1e-4)
class SeluGradOpTest(test.TestCase):
  """Numerically verifies the second-order gradients of Selu.

  Mirrors EluGradOpTest: builds `selu_grad = d(selu)/d(inputs)` with a
  fixed `grad_ys` tensor (`dummy`) and finite-difference-checks its
  gradient w.r.t. each of its two inputs.
  """

  @test_util.run_deprecated_v1
  def testSeluGradGradWRTgrad_ys(self):
    # Checks the gradient of SeluGrad w.r.t. the gradients pushed into it
    # (the `grad_ys` argument), which the fixed _SeluGradGrad now defines.
    inputs = constant_op.constant(
        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
    dummy = constant_op.constant(
        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)

    selu = gen_nn_ops.selu(inputs)
    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
    with self.cached_session():
      # Finite-difference check of d(selu_grad)/d(dummy).
      error = gradient_checker.compute_gradient_error(
          dummy,
          dummy.shape,
          selu_grad,
          selu_grad.shape)
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testSeluGradGradWRTinputs(self):
    # Checks the gradient of SeluGrad w.r.t. the original Selu inputs.
    inputs = constant_op.constant(
        [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
    dummy = constant_op.constant(
        [[3, 1, -1, -2], [9, 8, 7, 6]], dtype=dtypes.float32)

    selu = gen_nn_ops.selu(inputs)
    selu_grad = gradients_impl.gradients(selu, inputs, grad_ys=dummy)[0]
    with self.cached_session():
      # Finite-difference check of d(selu_grad)/d(inputs).
      error = gradient_checker.compute_gradient_error(
          inputs,
          inputs.shape,
          selu_grad,
          selu_grad.shape)
      self.assertLess(error, 1e-4)
||||
# Script entry point: run the test classes above under the TF test runner.
if __name__ == "__main__":
  test.main()
|
Loading…
Reference in New Issue
Block a user