Add forward compatibility guards around recent changes to singular gradients.
PiperOrigin-RevId: 238690453
parent d4b1b32629
commit 1c1c24c76b
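Every hunk below applies the same gate: the singularity-safe kernels (xdivy, mul_no_nan, div_no_nan) are only emitted once compat.forward_compatible(2019, 4, 7) reports that the forward-compatibility window has passed, so GraphDefs produced in the meantime keep using ops that older binaries understand. A minimal sketch of the pattern using the public tf.compat API; the helper name below is hypothetical, not code from this commit:

import tensorflow as tf

def _log_grad_sketch(grad, x):
  # Hypothetical stand-in for the guarded _LogGrad change below.
  if tf.compat.forward_compatible(2019, 4, 7):
    # New path: xdivy(grad, x) is 0 wherever grad is 0, even at x == 0.
    return tf.math.xdivy(grad, x)
  else:
    # Old path: reciprocal(0.) is inf, so 0. * inf produces nan.
    return grad * tf.math.reciprocal(x)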
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import ops
@@ -1115,6 +1116,9 @@ class SingularGradientOpTest(test.TestCase):

   @test_util.run_deprecated_v1
   def testGradientAtSingularity(self):
+    if not compat.forward_compatible(2019, 4, 7):
+      self.skipTest("Skipping test for future functionality.")
+
     ops_and_singularity = [
         (gen_math_ops.reciprocal, (0.,)),
         (gen_math_ops.rsqrt, (0.,)),
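The skip above keeps the test from running against binaries whose gradients still take the old, nan-producing paths. As a rough idea of what a gradient-at-singularity check looks like (a hypothetical sketch only, not the actual test body): evaluate each op's gradient exactly at its singular point with a zero upstream gradient and expect 0 rather than nan under the guarded formulations.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def grad_at_singularity(op_fn, point):
  # d(op_fn)/dx at the singular point, with a zero incoming gradient.
  x = tf.constant(point, dtype=tf.float32)
  y = op_fn(x)
  dydx, = tf.gradients(y, [x], grad_ys=[tf.zeros_like(y)])
  with tf.Session() as sess:
    return sess.run(dydx)

for fn, pt in [(tf.math.reciprocal, 0.), (tf.math.rsqrt, 0.)]:
  print(fn.__name__, grad_at_singularity(fn, pt))  # expect 0.0, not nan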
@@ -19,6 +19,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -459,8 +460,12 @@ def _SqrtGradGrad(op, grad):
   a = op.inputs[0]
   y = op.outputs[0]  # y = 0.5 * b / conj(a)
   with ops.control_dependencies([grad]):
-    ga = gen_math_ops.xdivy(grad, a)
-    return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    if compat.forward_compatible(2019, 4, 7):
+      ga = gen_math_ops.xdivy(grad, a)
+      return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    else:
+      ga = grad / a
+      return -math_ops.conj(ga) * y, 0.5 * ga


 @ops.RegisterGradient("Rsqrt")
@@ -508,7 +513,10 @@ def _LogGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, x)
+    else:
+      return grad * math_ops.reciprocal(x)


 @ops.RegisterGradient("Log1p")
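The xdivy substitution in _LogGrad (and in the similar hunks that follow) only changes behaviour at the singular point itself. A quick eager-mode illustration, assuming the public tf.math.xdivy matches the gen_math_ops kernel used in the diff:

import tensorflow as tf

grad = tf.constant([0., 1.])
x = tf.constant([0., 4.])

# Old path: reciprocal(0.) is inf, and 0. * inf is nan.
print((grad * tf.math.reciprocal(x)).numpy())  # [nan, 0.25]
# New path: xdivy returns 0 wherever the numerator (grad) is 0.
print(tf.math.xdivy(grad, x).numpy())          # [0., 0.25]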
@@ -517,7 +525,10 @@ def _Log1pGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, 1 + x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, 1 + x)
+    else:
+      return grad * math_ops.reciprocal(1 + x)


 @ops.RegisterGradient("Xlogy")
@@ -596,7 +607,10 @@ def _AcoshGrad(op, grad):
   y = op.outputs[0]
   with ops.control_dependencies([grad]):
     y = math_ops.conj(y)
-    return math_ops.xdivy(grad, math_ops.sinh(y))
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, math_ops.sinh(y))
+    else:
+      return grad / math_ops.sinh(y)


 @ops.RegisterGradient("Atanh")
@@ -831,7 +845,10 @@ def _TanGrad(op, grad):
     x = math_ops.conj(x)
     secx = math_ops.reciprocal(math_ops.cos(x))
     secx2 = math_ops.square(secx)
-    return math_ops.mul_no_nan(secx2, grad)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.mul_no_nan(secx2, grad)
+    else:
+      return secx2 * grad


 @ops.RegisterGradient("Asin")
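_TanGrad trades secx2 * grad for mul_no_nan(secx2, grad) behind the same gate. A small sketch of the behaviour the guarded branch relies on, using the public tf.math.multiply_no_nan (assumed here to match the internal mul_no_nan): the product is forced to 0 wherever the second operand, the incoming gradient, is 0, even if the first operand is inf.

import numpy as np
import tensorflow as tf

secx2 = tf.constant([np.inf, 4.])
grad = tf.constant([0., 1.])

print((secx2 * grad).numpy())                        # [nan, 4.]  old path
print(tf.math.multiply_no_nan(secx2, grad).numpy())  # [0., 4.]   guarded path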
@@ -843,7 +860,11 @@ def _AsinGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return grad * inv


 @ops.RegisterGradient("Acos")
@@ -855,7 +876,11 @@ def _AcosGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return -math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return -math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return -grad * inv


 @ops.RegisterGradient("Atan")
@@ -876,7 +901,10 @@ def _Atan2Grad(op, grad):
   y = op.inputs[0]
   x = op.inputs[1]
   with ops.control_dependencies([grad]):
-    grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    if compat.forward_compatible(2019, 4, 7):
+      grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    else:
+      grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
     return x * grad_inv, -y * grad_inv


@@ -978,11 +1006,13 @@ def _DivGrad(op, grad):
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+  if compat.forward_compatible(2019, 4, 7):
+    div_op = math_ops.div_no_nan
+  else:
+    div_op = math_ops.divide
+  return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
           array_ops.reshape(
-              math_ops.reduce_sum(
-                  grad * math_ops.div_no_nan(math_ops.divide(-x, y), y), ry),
+              math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
               sy))


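_DivGrad selects between div_no_nan and plain division through div_op. Unlike xdivy, div_no_nan keys on the denominator: the result is 0 wherever the divisor is 0. A short eager example with the public tf.math.divide_no_nan, assumed equivalent to math_ops.div_no_nan:

import tensorflow as tf

grad = tf.constant([1., 1.])
y = tf.constant([0., 2.])

print((grad / y).numpy())                      # [inf, 0.5]  old divide path
print(tf.math.divide_no_nan(grad, y).numpy())  # [0., 0.5]   guarded path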
@@ -1060,9 +1090,14 @@ def _PowGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   z = math_ops.conj(z)
-  gx = array_ops.reshape(
-      math_ops.reduce_sum(
-          math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+
+  if compat.forward_compatible(2019, 4, 7):
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(
+            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
   # Avoid false singularity at x = 0
   if x.dtype.is_complex:
     # real(x) < 0 is fine for the complex case
@@ -1072,8 +1107,11 @@ def _PowGrad(op, grad):
     mask = x > 0
   safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
   log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  gy = array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  if compat.forward_compatible(2019, 4, 7):
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy

