diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 20edafeeeef..3d6a915e115 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1327,35 +1327,61 @@ def _PowGrad(op, grad):
   """Returns grad * (y*x^(y-1), z*log(x))."""
   x = op.inputs[0]
   y = op.inputs[1]
-  z = op.outputs[0]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
+    # constant `1` into a single constant op.
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      x = math_ops.conj(x)
+      y = math_ops.conj(y)
+      if use_mul_no_nan:
+        return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
+      else:
+        return grad * y * math_ops.pow(x, y - 1), None
+
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  z = math_ops.conj(z)
-  if compat.forward_compatible(2019, 9, 14):
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(
-            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  if skip_input_indices is None or 0 not in skip_input_indices:
+    if use_mul_no_nan:
+      gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
+    else:
+      gx = grad * y * math_ops.pow(x, y - 1)
+    if must_reduce_x:
+      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
   else:
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
-  # Avoid false singularity at x = 0
-  if x.dtype.is_complex:
-    # real(x) < 0 is fine for the complex case
-    mask = math_ops.not_equal(x, 0)
+    gx = None
+
+  if skip_input_indices is None or 1 not in skip_input_indices:
+    z = math_ops.conj(op.outputs[0])
+
+    # Avoid false singularity at x = 0
+    if x.dtype.is_complex:
+      # real(x) < 0 is fine for the complex case
+      mask = math_ops.not_equal(x, 0)
+    else:
+      # There's no sensible real value to return if x < 0, so return 0
+      mask = x > 0
+    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
+    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
+    if use_mul_no_nan:
+      gy = gen_math_ops.mul_no_nan(z * log_x, grad)
+    else:
+      gy = grad * z * log_x
+    if must_reduce_y:
+      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
   else:
-    # There's no sensible real value to return if x < 0, so return 0
-    mask = x > 0
-  safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
-  log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  if compat.forward_compatible(2019, 9, 14):
-    gy = array_ops.reshape(
-        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
-  else:
-    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
+    gy = None
+
   return gx, gy


@@ -1423,15 +1449,39 @@ def _SquaredDifferenceGrad(op, grad):
   """Returns the gradient for (x-y)^2."""
   x = op.inputs[0]
   y = op.inputs[1]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
   with ops.control_dependencies([grad]):
     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
-  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
-          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
+
+  if (isinstance(grad, ops.Tensor) and
+      _ShapesFullySpecifiedAndEqual(x, y, grad)):
+    return x_grad, -x_grad
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif must_reduce_x:
+    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
+  else:
+    gx = x_grad
+
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif must_reduce_y:
+    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
+  else:
+    gy = -x_grad
+  return (gx, gy)


 # Logical operations have no gradients.
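Not part of the patch, only an illustrative sanity check of the formulas these two gradients implement: d(x**y)/dx = y*x**(y-1), d(x**y)/dy = z*log(x) with the x == 0 case masked to 0, and d((x-y)**2)/dx = 2*(x-y). The tensor values and the use of the public tf.GradientTape API below are my own choices, not something introduced by this change.

# Sketch only; assumes a TF 2.x eager environment.
import tensorflow as tf

x = tf.constant([0.0, 0.5, 2.0])
y = tf.constant([3.0, 3.0, 3.0])

with tf.GradientTape(persistent=True) as tape:
  tape.watch([x, y])
  z = tf.pow(x, y)                         # z = x**y
  sq = tf.math.squared_difference(x, y)    # (x - y)**2

dz_dx, dz_dy = tape.gradient(z, [x, y])
# dz/dx = y * x**(y - 1)             -> [0., 0.75, 12.]
# dz/dy = z * log(x), 0 at x == 0    -> [0., 0.125*ln(0.5), 8*ln(2)]
# The x == 0 entry is 0 rather than NaN thanks to the safe_x/log_x masking above.

dsq_dx, dsq_dy = tape.gradient(sq, [x, y])
# d/dx (x - y)**2 =  2*(x - y)  -> [-6., -5., -2.]
# d/dy (x - y)**2 = -2*(x - y)  -> [ 6.,  5.,  2.]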