From c99fc34eeba30aa4fd0b72c32c76b6a630a560c5 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 29 Jul 2019 20:46:32 -0700 Subject: [PATCH] Optimize the gradient function for tf.pow() and tf.squared_difference(). This change uses the recently added graph-level cache for the results of `broadcast_gradient_args()` when the inputs have statically known shapes. Using this cache, it can avoid generating unnecessary ops, which shrinks the graph and improves startup time. In addition, we add special cases to optimize `tf.pow(x, scalar)`, and non-broadcasting `tf.squared_difference(x, y)`. PiperOrigin-RevId: 260638889 --- tensorflow/python/ops/math_grad.py | 108 +++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 29 deletions(-) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 20edafeeeef..3d6a915e115 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1327,35 +1327,61 @@ def _PowGrad(op, grad): """Returns grad * (y*x^(y-1), z*log(x)).""" x = op.inputs[0] y = op.inputs[1] - z = op.outputs[0] - sx = array_ops.shape(x) - sy = array_ops.shape(y) - rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) + use_mul_no_nan = compat.forward_compatible(2019, 9, 14) + skip_input_indices = None + try: + skip_input_indices = op.skip_input_indices + # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the + # constant `1` into a single constant op. + if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( + y): + x = math_ops.conj(x) + y = math_ops.conj(y) + if use_mul_no_nan: + return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None + else: + return grad * y * math_ops.pow(x, y - 1), None + + except AttributeError: + # No gradient skipping, so do the full gradient computation + pass + + (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( + SmartBroadcastGradientArgs(x, y, grad)) x = math_ops.conj(x) y = math_ops.conj(y) - z = math_ops.conj(z) - if compat.forward_compatible(2019, 9, 14): - gx = array_ops.reshape( - math_ops.reduce_sum( - gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx) + if skip_input_indices is None or 0 not in skip_input_indices: + if use_mul_no_nan: + gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad) + else: + gx = grad * y * math_ops.pow(x, y - 1) + if must_reduce_x: + gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx) else: - gx = array_ops.reshape( - math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx) - # Avoid false singularity at x = 0 - if x.dtype.is_complex: - # real(x) < 0 is fine for the complex case - mask = math_ops.not_equal(x, 0) + gx = None + + if skip_input_indices is None or 1 not in skip_input_indices: + z = math_ops.conj(op.outputs[0]) + + # Avoid false singularity at x = 0 + if x.dtype.is_complex: + # real(x) < 0 is fine for the complex case + mask = math_ops.not_equal(x, 0) + else: + # There's no sensible real value to return if x < 0, so return 0 + mask = x > 0 + safe_x = array_ops.where(mask, x, array_ops.ones_like(x)) + log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x)) + if use_mul_no_nan: + gy = gen_math_ops.mul_no_nan(z * log_x, grad) + else: + gy = grad * z * log_x + if must_reduce_y: + gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy) else: - # There's no sensible real value to return if x < 0, so return 0 - mask = x > 0 - safe_x = array_ops.where(mask, x, array_ops.ones_like(x)) - log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x)) - if compat.forward_compatible(2019, 9, 14): - gy = array_ops.reshape( - math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy) - else: - gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy) + gy = None + return gx, gy @@ -1423,15 +1449,39 @@ def _SquaredDifferenceGrad(op, grad): """Returns the gradient for (x-y)^2.""" x = op.inputs[0] y = op.inputs[1] - sx = array_ops.shape(x) - sy = array_ops.shape(y) - rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) + skip_input_indices = None + try: + skip_input_indices = op.skip_input_indices + except AttributeError: + # No gradient skipping, so do the full gradient computation + pass + with ops.control_dependencies([grad]): # The parens ensure that if grad is IndexedSlices, it'll get multiplied by # Tensor (not a number like 2.0) which causes it to convert to Tensor. x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) - return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx), - -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)) + + if (isinstance(grad, ops.Tensor) and + _ShapesFullySpecifiedAndEqual(x, y, grad)): + return x_grad, -x_grad + + (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( + SmartBroadcastGradientArgs(x, y, grad)) + + if skip_input_indices is not None and 0 in skip_input_indices: + gx = None + elif must_reduce_x: + gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx) + else: + gx = x_grad + + if skip_input_indices is not None and 1 in skip_input_indices: + gy = None + elif must_reduce_y: + gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy) + else: + gy = -x_grad + return (gx, gy) # Logical operations have no gradients.