From e028703c1c69c7fefa77b8d79d45cad169ddbdca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Mar 2019 14:16:50 -0700 Subject: [PATCH] Fix bug in robust gradient of Div, RealDiv and DivNoNan. PiperOrigin-RevId: 239267389 --- .../python/kernel_tests/cwise_ops_test.py | 1 + tensorflow/python/ops/math_grad.py | 25 +++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index b3984ff864c..420cb7a5c81 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -1139,6 +1139,7 @@ class SingularGradientOpTest(test.TestCase): (gen_math_ops.acos, (1.,)), (gen_math_ops.atan2, (0., 0.)), (gen_math_ops.div, (1., 0.)), + (gen_math_ops.div_no_nan, (1., 0.)), (gen_math_ops.real_div, (1., 0.)), (math_ops.pow, (0., -1.)), ] diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 29e456584d6..43c8a4994c3 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1056,7 +1056,7 @@ def _DivGrad(op, grad): y = math_ops.conj(y) if compat.forward_compatible(2019, 4, 7): return (array_ops.reshape( - math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), + math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), array_ops.reshape( math_ops.reduce_sum( math_ops.mul_no_nan( @@ -1109,7 +1109,7 @@ def _RealDivGrad(op, grad): y = math_ops.conj(y) if compat.forward_compatible(2019, 4, 7): return (array_ops.reshape( - math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), + math_ops.reduce_sum(math_ops.xdivy(grad, y), rx), sx), array_ops.reshape( math_ops.reduce_sum( math_ops.mul_no_nan( @@ -1134,12 +1134,21 @@ def _DivNoNanGrad(op, grad): rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) - return (array_ops.reshape( - math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), - array_ops.reshape( - math_ops.reduce_sum( - grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), - ry), sy)) + if compat.forward_compatible(2019, 4, 7): + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + math_ops.mul_no_nan( + math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), + grad), ry), sy)) + else: + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), + ry), sy)) @ops.RegisterGradient("Pow")