Optimize the gradient function for tf.pow() and tf.squared_difference().

This change uses the recently added graph-level cache for the results of `broadcast_gradient_args()` when the inputs have statically known shapes. Using this cache, the gradient functions can avoid generating unnecessary shape and reduction ops, which shrinks the graph and improves startup time. In addition, we add special cases to optimize `tf.pow(x, scalar)` and non-broadcasting `tf.squared_difference(x, y)`.

PiperOrigin-RevId: 260638889
Authored by Derek Murray on 2019-07-29 20:46:32 -07:00; committed by TensorFlower Gardener
parent e5bfe5636c
commit c99fc34eeb

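For context on the statically-known-shapes fast path described in the commit message, here is a minimal sketch of the underlying idea: when both input shapes are fully known at graph-construction time, the broadcast reduction axes can be computed in Python, so the gradient function needs no Shape/BroadcastGradientArgs ops and can skip the reduce_sum/reshape entirely when nothing was broadcast. The helper name `_static_broadcast_gradient_args` is hypothetical and for illustration only; it is not the cache or the `SmartBroadcastGradientArgs` helper used in the diff below.

# Hypothetical sketch: compute broadcast-gradient reduction axes in Python
# for fully known shapes, mirroring NumPy/TF broadcasting rules.
def _static_broadcast_gradient_args(shape_x, shape_y):
  """Returns (axes to sum for grad(x), axes to sum for grad(y))."""
  rank = max(len(shape_x), len(shape_y))
  # Left-pad the shorter shape with 1s, as broadcasting does implicitly.
  padded_x = [1] * (rank - len(shape_x)) + list(shape_x)
  padded_y = [1] * (rank - len(shape_y)) + list(shape_y)
  rx, ry = [], []
  for axis, (dx, dy) in enumerate(zip(padded_x, padded_y)):
    if dx == 1 and dy != 1:
      rx.append(axis)  # x was broadcast along this axis; its gradient is summed.
    elif dy == 1 and dx != 1:
      ry.append(axis)  # y was broadcast along this axis.
  return rx, ry

# Example: x is broadcast against y along axis 1 only.
rx, ry = _static_broadcast_gradient_args([3, 1], [3, 4])
assert rx == [1] and ry == []  # only grad(x) needs a reduce_sum; grad(y) needs none.

When both lists are empty, the inputs did not broadcast at all, which is the case the `must_reduce_x`/`must_reduce_y` flags in the diff below let the gradient functions skip.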

@@ -1327,21 +1327,43 @@ def _PowGrad(op, grad):
   """Returns grad * (y*x^(y-1), z*log(x))."""
   x = op.inputs[0]
   y = op.inputs[1]
-  z = op.outputs[0]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
-  x = math_ops.conj(x)
-  y = math_ops.conj(y)
-  z = math_ops.conj(z)
-  if compat.forward_compatible(2019, 9, 14):
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(
-            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
-  else:
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
-  # Avoid false singularity at x = 0
-  if x.dtype.is_complex:
-    # real(x) < 0 is fine for the complex case
+  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
+    # constant `1` into a single constant op.
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      x = math_ops.conj(x)
+      y = math_ops.conj(y)
+      if use_mul_no_nan:
+        return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
+      else:
+        return grad * y * math_ops.pow(x, y - 1), None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+  x = math_ops.conj(x)
+  y = math_ops.conj(y)
+
+  if skip_input_indices is None or 0 not in skip_input_indices:
+    if use_mul_no_nan:
+      gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
+    else:
+      gx = grad * y * math_ops.pow(x, y - 1)
+    if must_reduce_x:
+      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
+  else:
+    gx = None
+
+  if skip_input_indices is None or 1 not in skip_input_indices:
+    z = math_ops.conj(op.outputs[0])
+
+    # Avoid false singularity at x = 0
+    if x.dtype.is_complex:
+      # real(x) < 0 is fine for the complex case
@@ -1351,11 +1373,15 @@ def _PowGrad(op, grad):
-    mask = x > 0
-  safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
-  log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  if compat.forward_compatible(2019, 9, 14):
-    gy = array_ops.reshape(
-        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
-  else:
-    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
+      mask = x > 0
+    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
+    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
+    if use_mul_no_nan:
+      gy = gen_math_ops.mul_no_nan(z * log_x, grad)
+    else:
+      gy = grad * z * log_x
+    if must_reduce_y:
+      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
+  else:
+    gy = None
   return gx, gy
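To see the `tf.pow(x, scalar)` special case from the caller's side, here is a hedged check (assuming a TF 2.x build that includes this change; `pow_grad` is just an example name). The gradient of `tf.pow(x, 3.0)` with respect to `x` is `3 * x**2`, and with a scalar, never-differentiated exponent the traced gradient graph should no longer need BroadcastGradientArgs, Sum, or Reshape nodes.

import tensorflow as tf

@tf.function
def pow_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    z = tf.pow(x, 3.0)  # scalar exponent whose gradient is never requested
  return tape.gradient(z, x)

concrete = pow_grad.get_concrete_function(tf.TensorSpec([4], tf.float32))
# 3 * x**2 -> [3., 12., 27., 48.]
print(concrete(tf.constant([1.0, 2.0, 3.0, 4.0])).numpy())
# On builds with this change, the op types below should not include
# BroadcastGradientArgs, Sum, or Reshape for this gradient.
print(sorted({op.type for op in concrete.graph.get_operations()}))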
@@ -1423,15 +1449,39 @@ def _SquaredDifferenceGrad(op, grad):
   """Returns the gradient for (x-y)^2."""
   x = op.inputs[0]
   y = op.inputs[1]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
   with ops.control_dependencies([grad]):
     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
-  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
-          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
+
+  if (isinstance(grad, ops.Tensor) and
+      _ShapesFullySpecifiedAndEqual(x, y, grad)):
+    return x_grad, -x_grad
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif must_reduce_x:
+    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
+  else:
+    gx = x_grad
+
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif must_reduce_y:
+    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
+  else:
+    gy = -x_grad
+
+  return (gx, gy)

 # Logical operations have no gradients.
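As a quick eager-mode sanity check of the non-broadcasting `tf.squared_difference` case (illustration only, not part of the change): when `x` and `y` share a fully specified shape, the gradient pair is simply `(2 * grad * (x - y), -2 * grad * (x - y))`, which is what the early `return x_grad, -x_grad` above produces without any Sum or Reshape ops.

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.constant([[0.5, 0.5], [0.5, 0.5]])

with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = tf.math.squared_difference(x, y)  # identical static shapes: no broadcasting

dx, dy = tape.gradient(z, [x, y])
# With an implicit upstream gradient of ones:
print(dx.numpy())  # dz/dx == 2 * (x - y): [[1., 3.], [5., 7.]]
print(dy.numpy())  # dz/dy == -2 * (x - y): [[-1., -3.], [-5., -7.]]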