Add forward compatibility guards around recent changes to singular gradients.

PiperOrigin-RevId: 238690453

commit 1c1c24c76b
parent d4b1b32629
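Every hunk below follows the same gating pattern: the NaN-safe ops introduced by the earlier singular-gradient change (xdivy, mul_no_nan, div_no_nan) are only emitted into the graph once the forward-compatibility horizon of 2019-04-07 has passed, so GraphDefs produced by this Python frontend keep loading on older binaries in the meantime. A minimal sketch of the pattern, assuming only the real compat.forward_compatible API; _reciprocal_grad_sketch is an illustrative name, not code from this commit:

    # Sketch of the guard pattern used throughout this commit.
    from tensorflow.python.compat import compat
    from tensorflow.python.ops import math_ops


    def _reciprocal_grad_sketch(x, grad):
      # After the horizon, emit the NaN-safe op: xdivy(g, x2) is 0 when
      # g == 0, even at x == 0, so the gradient no longer becomes nan there.
      if compat.forward_compatible(2019, 4, 7):
        return -math_ops.xdivy(grad, math_ops.square(x))
      # Before the horizon, keep emitting the old graph, which binaries
      # built before the new kernels existed can still execute.
      return -grad / math_ops.square(x)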
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import ops
@@ -1115,6 +1116,9 @@ class SingularGradientOpTest(test.TestCase):

   @test_util.run_deprecated_v1
   def testGradientAtSingularity(self):
+    if not compat.forward_compatible(2019, 4, 7):
+      self.skipTest("Skipping test for future functionality.")
+
     ops_and_singularity = [
         (gen_math_ops.reciprocal, (0.,)),
         (gen_math_ops.rsqrt, (0.,)),
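The remaining hunks apply the same guard inside the registered gradient functions in tensorflow/python/ops/math_grad.py. As a usage note (not part of this commit): a test like the one above can opt into the post-horizon behavior early via the companion context manager in the same compat module:

    from tensorflow.python.compat import compat

    # Pretend the horizon has already passed, so the guarded branches and
    # the skipped test body above become reachable before 2019-04-07.
    with compat.forward_compatibility_horizon(2019, 4, 8):
      assert compat.forward_compatible(2019, 4, 7)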
@@ -19,6 +19,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -459,8 +460,12 @@ def _SqrtGradGrad(op, grad):
   a = op.inputs[0]
   y = op.outputs[0]  # y = 0.5 * b / conj(a)
   with ops.control_dependencies([grad]):
-    ga = gen_math_ops.xdivy(grad, a)
-    return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    if compat.forward_compatible(2019, 4, 7):
+      ga = gen_math_ops.xdivy(grad, a)
+      return -gen_math_ops.mul_no_nan(y, math_ops.conj(ga)), 0.5 * ga
+    else:
+      ga = grad / a
+      return -math_ops.conj(ga) * y, 0.5 * ga


 @ops.RegisterGradient("Rsqrt")
@@ -508,7 +513,10 @@ def _LogGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, x)
+    else:
+      return grad * math_ops.reciprocal(x)


 @ops.RegisterGradient("Log1p")
@@ -517,7 +525,10 @@ def _Log1pGrad(op, grad):
   x = op.inputs[0]
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return gen_math_ops.xdivy(grad, 1 + x)
+    if compat.forward_compatible(2019, 4, 7):
+      return gen_math_ops.xdivy(grad, 1 + x)
+    else:
+      return grad * math_ops.reciprocal(1 + x)


 @ops.RegisterGradient("Xlogy")
@@ -596,7 +607,10 @@ def _AcoshGrad(op, grad):
   y = op.outputs[0]
   with ops.control_dependencies([grad]):
     y = math_ops.conj(y)
-    return math_ops.xdivy(grad, math_ops.sinh(y))
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, math_ops.sinh(y))
+    else:
+      return grad / math_ops.sinh(y)


 @ops.RegisterGradient("Atanh")
@@ -831,7 +845,10 @@ def _TanGrad(op, grad):
     x = math_ops.conj(x)
     secx = math_ops.reciprocal(math_ops.cos(x))
     secx2 = math_ops.square(secx)
-    return math_ops.mul_no_nan(secx2, grad)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.mul_no_nan(secx2, grad)
+    else:
+      return secx2 * grad


 @ops.RegisterGradient("Asin")
@@ -843,7 +860,11 @@ def _AsinGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return grad * inv


 @ops.RegisterGradient("Acos")
@@ -855,7 +876,11 @@ def _AcosGrad(op, grad):
     x2 = math_ops.square(x)
     one = constant_op.constant(1, dtype=grad.dtype)
     den = math_ops.sqrt(math_ops.subtract(one, x2))
-    return -math_ops.xdivy(grad, den)
+    if compat.forward_compatible(2019, 4, 7):
+      return -math_ops.xdivy(grad, den)
+    else:
+      inv = math_ops.reciprocal(den)
+      return -grad * inv


 @ops.RegisterGradient("Atan")
@@ -876,7 +901,10 @@ def _Atan2Grad(op, grad):
   y = op.inputs[0]
   x = op.inputs[1]
   with ops.control_dependencies([grad]):
-    grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    if compat.forward_compatible(2019, 4, 7):
+      grad_inv = math_ops.xdivy(grad, (math_ops.square(x) + math_ops.square(y)))
+    else:
+      grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
     return x * grad_inv, -y * grad_inv

@@ -978,11 +1006,13 @@ def _DivGrad(op, grad):
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+  if compat.forward_compatible(2019, 4, 7):
+    div_op = math_ops.div_no_nan
+  else:
+    div_op = math_ops.divide
+  return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
           array_ops.reshape(
-              math_ops.reduce_sum(
-                  grad * math_ops.div_no_nan(math_ops.divide(-x, y), y), ry),
+              math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
               sy))

@@ -1060,9 +1090,14 @@ def _PowGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   z = math_ops.conj(z)
-  gx = array_ops.reshape(
-      math_ops.reduce_sum(
-          math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+
+  if compat.forward_compatible(2019, 4, 7):
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(
+            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
   # Avoid false singularity at x = 0
   if x.dtype.is_complex:
     # real(x) < 0 is fine for the complex case
@@ -1072,8 +1107,11 @@ def _PowGrad(op, grad):
   mask = x > 0
   safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
   log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  gy = array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  if compat.forward_compatible(2019, 4, 7):
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy

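What the guarded branches change at the singular points, illustrated with the public TF 2.x eager names for the same kernels (tf.math.xdivy, tf.math.divide_no_nan, tf.math.multiply_no_nan); this is a behavioral note, not part of the diff:

    import tensorflow as tf

    # xdivy(x, y) returns 0 whenever x == 0, even at y == 0 (plain 0./0. is nan).
    print(tf.math.xdivy(0.0, 0.0).numpy())                      # 0.0
    # divide_no_nan(x, y) returns 0 whenever y == 0 (plain 1./0. is inf).
    print(tf.math.divide_no_nan(1.0, 0.0).numpy())              # 0.0
    # multiply_no_nan(x, y) returns 0 whenever y == 0, even if x is inf or nan.
    print(tf.math.multiply_no_nan(float("inf"), 0.0).numpy())   # 0.0

This zero-instead-of-inf/nan behavior is what lets testGradientAtSingularity expect a finite gradient contribution at points like reciprocal(0.) once the horizon passes.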