Add forward compatibility guards around recent changes to singular gradients.

PiperOrigin-RevId: 238690453
commit 1c1c24c76b (parent d4b1b32629)
Author: A. Unique TensorFlower
Date:   2019-03-15 12:34:21 -07:00
Committed-by: TensorFlower Gardener

2 changed files with 60 additions and 18 deletions
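
Every hunk below follows the same pattern: the new kernel-based gradient (xdivy, mul_no_nan, div_no_nan) is only emitted once the forward-compatibility horizon of 2019-04-07 has passed, so a graph produced by a newer Python front end can still be executed by slightly older runtime binaries. A minimal sketch of that pattern, using an illustrative gradient function that is not part of this diff:

# Hypothetical sketch of the guard pattern; _ExampleLogLikeGrad is illustrative only.
from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops


def _ExampleLogLikeGrad(grad, x):
  if compat.forward_compatible(2019, 4, 7):
    # New behaviour: a zero upstream gradient stays exactly zero at x == 0.
    return gen_math_ops.xdivy(grad, x)
  else:
    # Old formula, kept so graphs still run on binaries without the new kernel.
    return grad * math_ops.reciprocal(x)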

Changed file 1 of 2:

@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import ops
@@ -1115,6 +1116,9 @@ class SingularGradientOpTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testGradientAtSingularity(self):
+    if not compat.forward_compatible(2019, 4, 7):
+      self.skipTest("Skipping test for future functionality.")
+
     ops_and_singularity = [
         (gen_math_ops.reciprocal, (0.,)),
         (gen_math_ops.rsqrt, (0.,)),
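
The guarded test above exercises the new behaviour: at each listed singularity a zero upstream gradient should come back as an exact zero rather than nan. A hypothetical standalone check of the same property with the public API (this is not the test itself):

import tensorflow as tf

x = tf.constant(0.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.math.log(x)  # singular at x == 0
# Zero upstream gradient, as if this element were masked out downstream.
dx = tape.gradient(y, x, output_gradients=tf.constant(0.0))
print(dx)  # expected 0.0 with the xdivy-based gradient; grad / x would give nan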

Changed file 2 of 2:

@@ -19,6 +19,7 @@ from __future__ import print_function
 import numpy as np
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -459,8 +460,12 @@ def _SqrtGradGrad(op, grad):
   a = op.inputs[0]
   y = op.outputs[0]  # y = 0.5 * b / conj(a)
   with ops.control_dependencies([grad]):
-    ga = gen_math_ops.xdivy(grad, a)
-    return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    if compat.forward_compatible(2019, 4, 7):
+      ga = gen_math_ops.xdivy(grad, a)
+      return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    else:
+      ga = grad / a
+      return -math_ops.conj(ga) * y, 0.5 * ga
 @ops.RegisterGradient("Rsqrt")
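
The two branches differ only in how a zero numerator is treated at a == 0: xdivy returns 0 whenever its first argument is 0, while plain division yields 0 / 0 = nan. A quick sketch with the public tf.math.xdivy wrapper:

import tensorflow as tf

print(tf.math.xdivy(0.0, 0.0))              # 0.0: zero numerator short-circuits
print(tf.math.xdivy(1.0, 0.0))              # inf: nonzero numerator still diverges
print(tf.constant(0.0) / tf.constant(0.0))  # nan: the old grad / a behaviour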
@@ -508,7 +513,10 @@ def _LogGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, x)
+    else:
+      return grad * math_ops.reciprocal(x)
 @ops.RegisterGradient("Log1p")
@@ -517,7 +525,10 @@ def _Log1pGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, 1 + x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, 1 + x)
+    else:
+      return grad * math_ops.reciprocal(1 + x)
 @ops.RegisterGradient("Xlogy")
@@ -596,7 +607,10 @@ def _AcoshGrad(op, grad):
   y = op.outputs[0]
   with ops.control_dependencies([grad]):
     y = math_ops.conj(y)
-    return math_ops.xdivy(grad, math_ops.sinh(y))
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, math_ops.sinh(y))
+    else:
+      return grad / math_ops.sinh(y)
 @ops.RegisterGradient("Atanh")
@@ -831,7 +845,10 @@ def _TanGrad(op, grad):
     x = math_ops.conj(x)
     secx = math_ops.reciprocal(math_ops.cos(x))
     secx2 = math_ops.square(secx)
-    return math_ops.mul_no_nan(secx2, grad)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.mul_no_nan(secx2, grad)
+    else:
+      return secx2 * grad
 @ops.RegisterGradient("Asin")
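
mul_no_nan has the complementary contract: the product is forced to 0 wherever its second argument (here the upstream gradient) is 0, even if the first argument has overflowed to inf, as sec^2(x) does at the poles of tan. A sketch with the public tf.math.multiply_no_nan wrapper:

import tensorflow as tf

inf = tf.constant(float("inf"))
print(tf.math.multiply_no_nan(inf, tf.constant(0.0)))  # 0.0
print(inf * tf.constant(0.0))                          # nan: the old secx2 * grad path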
@@ -843,7 +860,11 @@ def _AsinGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return grad * inv
 @ops.RegisterGradient("Acos")
@@ -855,7 +876,11 @@ def _AcosGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return -math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return -math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return -grad * inv
 @ops.RegisterGradient("Atan")
@@ -876,7 +901,10 @@ def _Atan2Grad(op, grad):
   y = op.inputs[0]
   x = op.inputs[1]
   with ops.control_dependencies([grad]):
-    grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    if compat.forward_compatible(2019, 4, 7):
+      grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    else:
+      grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
     return x * grad_inv, -y * grad_inv
@@ -978,11 +1006,13 @@ def _DivGrad(op, grad):
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+  if compat.forward_compatible(2019, 4, 7):
+    div_op = math_ops.div_no_nan
+  else:
+    div_op = math_ops.divide
+  return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
           array_ops.reshape(
-              math_ops.reduce_sum(
-                  grad * math_ops.div_no_nan(math_ops.divide(-x, y), y), ry),
+              math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
               sy))
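
div_no_nan keys on the denominator instead: any element with a zero divisor contributes a 0 gradient rather than inf or nan, and the guard only swaps which division op is used to build the graph. A sketch with the public tf.math.divide_no_nan wrapper:

import tensorflow as tf

print(tf.math.divide_no_nan(1.0, 0.0))      # 0.0
print(tf.math.divide_no_nan(0.0, 0.0))      # 0.0
print(tf.constant(1.0) / tf.constant(0.0))  # inf: the old math_ops.divide behaviour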
@@ -1060,9 +1090,14 @@ def _PowGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   z = math_ops.conj(z)
-  gx = array_ops.reshape(
-      math_ops.reduce_sum(
-          math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  if compat.forward_compatible(2019, 4, 7):
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(
+            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
   # Avoid false singularity at x = 0
   if x.dtype.is_complex:
     # real(x) < 0 is fine for the complex case
@@ -1072,8 +1107,11 @@ def _PowGrad(op, grad):
     mask = x > 0
   safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
   log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  gy = array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  if compat.forward_compatible(2019, 4, 7):
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy
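
For Pow, the mask/safe_x/log_x machinery above already zeroes log(x) at x == 0, and mul_no_nan additionally keeps a zero upstream gradient from resurrecting a nan. A hypothetical end-to-end check of the exponent gradient at the masked point:

import tensorflow as tf

x = tf.constant(0.0)
y = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(y)
  z = tf.math.pow(x, y)  # dz/dy = z * log(x), which is singular at x == 0
dy = tape.gradient(z, y)
print(dy)  # expected 0.0, since log_x is masked to 0 at x == 0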