Fix bug in robust gradient of Div, RealDiv and DivNoNan.
PiperOrigin-RevId: 239267389
This commit is contained in:
parent
d43453baa5
commit
e028703c1c
@ -1139,6 +1139,7 @@ class SingularGradientOpTest(test.TestCase):
|
||||
(gen_math_ops.acos, (1.,)),
|
||||
(gen_math_ops.atan2, (0., 0.)),
|
||||
(gen_math_ops.div, (1., 0.)),
|
||||
(gen_math_ops.div_no_nan, (1., 0.)),
|
||||
(gen_math_ops.real_div, (1., 0.)),
|
||||
(math_ops.pow, (0., -1.)),
|
||||
]
|
||||
|
@ -1056,7 +1056,7 @@ def _DivGrad(op, grad):
|
||||
y = math_ops.conj(y)
|
||||
if compat.forward_compatible(2019, 4, 7):
|
||||
return (array_ops.reshape(
|
||||
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
|
||||
math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx),
|
||||
array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.mul_no_nan(
|
||||
@ -1109,7 +1109,7 @@ def _RealDivGrad(op, grad):
|
||||
y = math_ops.conj(y)
|
||||
if compat.forward_compatible(2019, 4, 7):
|
||||
return (array_ops.reshape(
|
||||
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
|
||||
math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx),
|
||||
array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.mul_no_nan(
|
||||
@ -1134,12 +1134,21 @@ def _DivNoNanGrad(op, grad):
|
||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
||||
x = math_ops.conj(x)
|
||||
y = math_ops.conj(y)
|
||||
return (array_ops.reshape(
|
||||
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
|
||||
array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
|
||||
ry), sy))
|
||||
if compat.forward_compatible(2019, 4, 7):
|
||||
return (array_ops.reshape(
|
||||
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
|
||||
array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.mul_no_nan(
|
||||
math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
|
||||
grad), ry), sy))
|
||||
else:
|
||||
return (array_ops.reshape(
|
||||
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
|
||||
array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
|
||||
ry), sy))
|
||||
|
||||
|
||||
@ops.RegisterGradient("Pow")
|
||||
|
Loading…
x
Reference in New Issue
Block a user