Optimize the gradient function for tf.pow() and tf.squared_difference().
This change uses the recently added graph-level cache for the results of `broadcast_gradient_args()` when the inputs have statically known shapes. Using this cache, it can avoid generating unnecessary ops, which shrinks the graph and improves startup time. In addition, we add special cases to optimize `tf.pow(x, scalar)`, and non-broadcasting `tf.squared_difference(x, y)`. PiperOrigin-RevId: 260638889
This commit is contained in:
parent
e5bfe5636c
commit
c99fc34eeb
@ -1327,21 +1327,43 @@ def _PowGrad(op, grad):
|
||||
"""Returns grad * (y*x^(y-1), z*log(x))."""
|
||||
x = op.inputs[0]
|
||||
y = op.inputs[1]
|
||||
z = op.outputs[0]
|
||||
sx = array_ops.shape(x)
|
||||
sy = array_ops.shape(y)
|
||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
||||
use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
|
||||
skip_input_indices = None
|
||||
try:
|
||||
skip_input_indices = op.skip_input_indices
|
||||
# TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
|
||||
# constant `1` into a single constant op.
|
||||
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
|
||||
y):
|
||||
x = math_ops.conj(x)
|
||||
y = math_ops.conj(y)
|
||||
z = math_ops.conj(z)
|
||||
|
||||
if compat.forward_compatible(2019, 9, 14):
|
||||
gx = array_ops.reshape(
|
||||
math_ops.reduce_sum(
|
||||
gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
|
||||
if use_mul_no_nan:
|
||||
return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
|
||||
else:
|
||||
gx = array_ops.reshape(
|
||||
math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
|
||||
return grad * y * math_ops.pow(x, y - 1), None
|
||||
|
||||
except AttributeError:
|
||||
# No gradient skipping, so do the full gradient computation
|
||||
pass
|
||||
|
||||
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
|
||||
SmartBroadcastGradientArgs(x, y, grad))
|
||||
x = math_ops.conj(x)
|
||||
y = math_ops.conj(y)
|
||||
|
||||
if skip_input_indices is None or 0 not in skip_input_indices:
|
||||
if use_mul_no_nan:
|
||||
gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
|
||||
else:
|
||||
gx = grad * y * math_ops.pow(x, y - 1)
|
||||
if must_reduce_x:
|
||||
gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
|
||||
else:
|
||||
gx = None
|
||||
|
||||
if skip_input_indices is None or 1 not in skip_input_indices:
|
||||
z = math_ops.conj(op.outputs[0])
|
||||
|
||||
# Avoid false singularity at x = 0
|
||||
if x.dtype.is_complex:
|
||||
# real(x) < 0 is fine for the complex case
|
||||
@ -1351,11 +1373,15 @@ def _PowGrad(op, grad):
|
||||
mask = x > 0
|
||||
safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
|
||||
log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
|
||||
if compat.forward_compatible(2019, 9, 14):
|
||||
gy = array_ops.reshape(
|
||||
math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
|
||||
if use_mul_no_nan:
|
||||
gy = gen_math_ops.mul_no_nan(z * log_x, grad)
|
||||
else:
|
||||
gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
|
||||
gy = grad * z * log_x
|
||||
if must_reduce_y:
|
||||
gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
|
||||
else:
|
||||
gy = None
|
||||
|
||||
return gx, gy
|
||||
|
||||
|
||||
@ -1423,15 +1449,39 @@ def _SquaredDifferenceGrad(op, grad):
|
||||
"""Returns the gradient for (x-y)^2."""
|
||||
x = op.inputs[0]
|
||||
y = op.inputs[1]
|
||||
sx = array_ops.shape(x)
|
||||
sy = array_ops.shape(y)
|
||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
||||
skip_input_indices = None
|
||||
try:
|
||||
skip_input_indices = op.skip_input_indices
|
||||
except AttributeError:
|
||||
# No gradient skipping, so do the full gradient computation
|
||||
pass
|
||||
|
||||
with ops.control_dependencies([grad]):
|
||||
# The parens ensure that if grad is IndexedSlices, it'll get multiplied by
|
||||
# Tensor (not a number like 2.0) which causes it to convert to Tensor.
|
||||
x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
|
||||
return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
|
||||
-array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
|
||||
|
||||
if (isinstance(grad, ops.Tensor) and
|
||||
_ShapesFullySpecifiedAndEqual(x, y, grad)):
|
||||
return x_grad, -x_grad
|
||||
|
||||
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
|
||||
SmartBroadcastGradientArgs(x, y, grad))
|
||||
|
||||
if skip_input_indices is not None and 0 in skip_input_indices:
|
||||
gx = None
|
||||
elif must_reduce_x:
|
||||
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
|
||||
else:
|
||||
gx = x_grad
|
||||
|
||||
if skip_input_indices is not None and 1 in skip_input_indices:
|
||||
gy = None
|
||||
elif must_reduce_y:
|
||||
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
|
||||
else:
|
||||
gy = -x_grad
|
||||
return (gx, gy)
|
||||
|
||||
|
||||
# Logical operations have no gradients.
|
||||
|
Loading…
Reference in New Issue
Block a user