parent 623abf22f8
commit d8563bdb14
@@ -1327,43 +1327,21 @@ def _PowGrad(op, grad):
   """Returns grad * (y*x^(y-1), z*log(x))."""
   x = op.inputs[0]
   y = op.inputs[1]
-  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
-  skip_input_indices = None
-  try:
-    skip_input_indices = op.skip_input_indices
-    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
-    # constant `1` into a single constant op.
-    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
-        y):
-      x = math_ops.conj(x)
-      y = math_ops.conj(y)
-      if use_mul_no_nan:
-        return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
-      else:
-        return grad * y * math_ops.pow(x, y - 1), None
-  except AttributeError:
-    # No gradient skipping, so do the full gradient computation
-    pass
-
-  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
-      SmartBroadcastGradientArgs(x, y, grad))
+  z = op.outputs[0]
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
+  z = math_ops.conj(z)
 
-  if skip_input_indices is None or 0 not in skip_input_indices:
-    if use_mul_no_nan:
-      gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
-    else:
-      gx = grad * y * math_ops.pow(x, y - 1)
-    if must_reduce_x:
-      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
-  else:
-    gx = None
-
-  if skip_input_indices is None or 1 not in skip_input_indices:
-    z = math_ops.conj(op.outputs[0])
-
+  if compat.forward_compatible(2019, 9, 14):
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(
+            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
   # Avoid false singularity at x = 0
   if x.dtype.is_complex:
     # real(x) < 0 is fine for the complex case
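Both the removed and the added gx paths in the hunk above do the same broadcasting bookkeeping: the upstream gradient has the broadcast shape of x^y, so it is summed over the axes on which x was broadcast (rx) and reshaped back to x's shape (sx). Below is a minimal sketch of that pattern using only the public TensorFlow API; reduce_broadcast_grad is a hypothetical helper name, and the only op assumed is tf.raw_ops.BroadcastGradientArgs, which backs gen_array_ops.broadcast_gradient_args. The _PowGrad diff continues in the next hunk.

import tensorflow as tf


def reduce_broadcast_grad(grad, input_shape, other_shape):
  """Sums `grad` over the broadcast axes, then restores `input_shape`.

  Mirrors the reduce_sum/reshape pattern above: the first output of
  BroadcastGradientArgs plays the role of `rx`.
  """
  rx, _ = tf.raw_ops.BroadcastGradientArgs(s0=input_shape, s1=other_shape)
  return tf.reshape(tf.reduce_sum(grad, rx), input_shape)


# x: shape [3, 1], y: shape [1, 4]; x ** y broadcasts to [3, 4], so the
# upstream gradient arrives with shape [3, 4] and must be reduced for x.
grad = tf.ones([3, 4])
gx = reduce_broadcast_grad(grad, tf.constant([3, 1]), tf.constant([1, 4]))
print(gx.shape)  # (3, 1)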
@@ -1373,15 +1351,11 @@ def _PowGrad(op, grad):
     mask = x > 0
   safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
   log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  if use_mul_no_nan:
-    gy = gen_math_ops.mul_no_nan(z * log_x, grad)
-  else:
-    gy = grad * z * log_x
-  if must_reduce_y:
-    gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
-  else:
-    gy = None
-
+  if compat.forward_compatible(2019, 9, 14):
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy
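Taken together, the two hunks leave _PowGrad computing the same mathematical gradient as before the parent change: for z = x^y, dz/dx = y*x^(y-1) and dz/dy = z*log(x), with log(x) masked to zero wherever x <= 0 so that a zero or negative base does not poison the gradient with NaN or Inf. A small sanity check through the public tf.GradientTape API follows; it is a sketch, the input values are arbitrary, and it is not part of the commit.

import numpy as np
import tensorflow as tf

x = tf.constant([2.0, 0.0, 3.0])  # includes a zero base on purpose
y = tf.constant([3.0, 2.0, 0.5])

with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  tape.watch(y)
  z = tf.pow(x, y)

gx = tape.gradient(z, x)
gy = tape.gradient(z, y)

# dz/dx = y * x^(y-1); dz/dy = z * log(x), defined as 0 where x <= 0.
safe_log_x = tf.where(x > 0,
                      tf.math.log(tf.where(x > 0, x, tf.ones_like(x))),
                      tf.zeros_like(x))
expected_gx = y * tf.pow(x, y - 1)
expected_gy = z * safe_log_x

print(np.allclose(gx, expected_gx), np.allclose(gy, expected_gy))  # True True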
@@ -1449,39 +1423,15 @@ def _SquaredDifferenceGrad(op, grad):
   """Returns the gradient for (x-y)^2."""
   x = op.inputs[0]
   y = op.inputs[1]
-  skip_input_indices = None
-  try:
-    skip_input_indices = op.skip_input_indices
-  except AttributeError:
-    # No gradient skipping, so do the full gradient computation
-    pass
-
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   with ops.control_dependencies([grad]):
     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
-
-  if (isinstance(grad, ops.Tensor) and
-      _ShapesFullySpecifiedAndEqual(x, y, grad)):
-    return x_grad, -x_grad
-
-  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
-      SmartBroadcastGradientArgs(x, y, grad))
-
-  if skip_input_indices is not None and 0 in skip_input_indices:
-    gx = None
-  elif must_reduce_x:
-    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
-  else:
-    gx = x_grad
-
-  if skip_input_indices is not None and 1 in skip_input_indices:
-    gy = None
-  elif must_reduce_y:
-    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
-  else:
-    gy = -x_grad
-  return (gx, gy)
+  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
 
 
 # Logical operations have no gradients.
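As with _PowGrad, the removed and the added bodies of _SquaredDifferenceGrad encode the same math: for s = (x-y)^2, ds/dx = 2*(x-y) and ds/dy = -2*(x-y), each summed over the broadcast axes and reshaped back to the corresponding input's shape. The check below is a sketch against the public API; the shapes and values are arbitrary examples, not taken from the commit.

import numpy as np
import tensorflow as tf

x = tf.constant([[1.0], [4.0], [9.0]])    # shape [3, 1]
y = tf.constant([[0.5, 2.0, 3.0, 10.0]])  # shape [1, 4]

with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  tape.watch(y)
  s = tf.math.squared_difference(x, y)    # broadcasts to [3, 4]

gx = tape.gradient(s, x)
gy = tape.gradient(s, y)

# ds/dx = 2 * (x - y) reduced back to x's shape; ds/dy is the negation,
# reduced back to y's shape.
expected_gx = tf.reduce_sum(2.0 * (x - y), axis=1, keepdims=True)   # [3, 1]
expected_gy = tf.reduce_sum(-2.0 * (x - y), axis=0, keepdims=True)  # [1, 4]

print(np.allclose(gx, expected_gx), np.allclose(gy, expected_gy))  # True True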