diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 3d6a915e115..20edafeeeef 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1327,61 +1327,35 @@ def _PowGrad(op, grad):
   """Returns grad * (y*x^(y-1), z*log(x))."""
   x = op.inputs[0]
   y = op.inputs[1]
-  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
-  skip_input_indices = None
-  try:
-    skip_input_indices = op.skip_input_indices
-    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
-    # constant `1` into a single constant op.
-    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
-        y):
-      x = math_ops.conj(x)
-      y = math_ops.conj(y)
-      if use_mul_no_nan:
-        return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
-      else:
-        return grad * y * math_ops.pow(x, y - 1), None
-
-  except AttributeError:
-    # No gradient skipping, so do the full gradient computation
-    pass
-
-  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
-      SmartBroadcastGradientArgs(x, y, grad))
+  z = op.outputs[0]
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
+  z = math_ops.conj(z)

-  if skip_input_indices is None or 0 not in skip_input_indices:
-    if use_mul_no_nan:
-      gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
-    else:
-      gx = grad * y * math_ops.pow(x, y - 1)
-    if must_reduce_x:
-      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
+  if compat.forward_compatible(2019, 9, 14):
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(
+            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
   else:
-    gx = None
-
-  if skip_input_indices is None or 1 not in skip_input_indices:
-    z = math_ops.conj(op.outputs[0])
-
-    # Avoid false singularity at x = 0
-    if x.dtype.is_complex:
-      # real(x) < 0 is fine for the complex case
-      mask = math_ops.not_equal(x, 0)
-    else:
-      # There's no sensible real value to return if x < 0, so return 0
-      mask = x > 0
-    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
-    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-    if use_mul_no_nan:
-      gy = gen_math_ops.mul_no_nan(z * log_x, grad)
-    else:
-      gy = grad * z * log_x
-    if must_reduce_y:
-      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
+  # Avoid false singularity at x = 0
+  if x.dtype.is_complex:
+    # real(x) < 0 is fine for the complex case
+    mask = math_ops.not_equal(x, 0)
   else:
-    gy = None
-
+    # There's no sensible real value to return if x < 0, so return 0
+    mask = x > 0
+  safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
+  log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
+  if compat.forward_compatible(2019, 9, 14):
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy


@@ -1449,39 +1423,15 @@ def _SquaredDifferenceGrad(op, grad):
   """Returns the gradient for (x-y)^2."""
   x = op.inputs[0]
   y = op.inputs[1]
-  skip_input_indices = None
-  try:
-    skip_input_indices = op.skip_input_indices
-  except AttributeError:
-    # No gradient skipping, so do the full gradient computation
-    pass
-
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   with ops.control_dependencies([grad]):
     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
-
-  if (isinstance(grad, ops.Tensor) and
-      _ShapesFullySpecifiedAndEqual(x, y, grad)):
-    return x_grad, -x_grad
-
-  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
-      SmartBroadcastGradientArgs(x, y, grad))
-
-  if skip_input_indices is not None and 0 in skip_input_indices:
-    gx = None
-  elif must_reduce_x:
-    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
-  else:
-    gx = x_grad
-
-  if skip_input_indices is not None and 1 in skip_input_indices:
-    gy = None
-  elif must_reduce_y:
-    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
-  else:
-    gy = -x_grad
-  return (gx, gy)
+  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
+          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))


 # Logical operations have no gradients.
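
Note (illustrative, not part of the diff): the restored `_PowGrad` computes dz/dx = y * x^(y-1) and dz/dy = z * log(x), masking log(x) wherever x <= 0 for real inputs so the y-gradient stays finite at x = 0. A minimal NumPy sketch of the same formulas and masking; `pow_grads` is a hypothetical helper, and it ignores the complex-dtype and broadcast-reduction handling that the TensorFlow code does via conj and reduce_sum/reshape:

import numpy as np

def pow_grads(x, y, grad):
  """Gradients of z = x**y w.r.t. x and y, with the same log(x) masking."""
  z = x ** y
  gx = grad * y * x ** (y - 1)
  # There's no sensible real value for log(x) when x <= 0, so use 0 there,
  # mirroring the safe_x / log_x trick in _PowGrad.
  mask = x > 0
  safe_x = np.where(mask, x, np.ones_like(x))
  log_x = np.where(mask, np.log(safe_x), np.zeros_like(x))
  gy = grad * z * log_x
  return gx, gy

# x = 2, y = 3, upstream grad = 1  ->  gx = 3 * 2**2 = 12, gy = 8 * ln(2)
print(pow_grads(np.array(2.0), np.array(3.0), np.array(1.0)))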
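
Also illustrative, not from the diff: both hunks recover per-input gradients after broadcasting with the reshape(reduce_sum(g, r), s) pattern, where r comes from broadcast_gradient_args. A NumPy stand-in for that reduction (`reduce_to_shape` is a hypothetical helper, assuming broadcasting only adds leading axes or size-1 axes), applied to the squared-difference case:

def reduce_to_shape(g, shape):
  """Sum a broadcast gradient back down to `shape`.

  NumPy stand-in for reshape(reduce_sum(g, r), s), where TensorFlow's
  broadcast_gradient_args would compute the reduction axes r.
  """
  ndiff = g.ndim - len(shape)
  axes = tuple(range(ndiff)) + tuple(
      i + ndiff for i, d in enumerate(shape) if d == 1)
  return g.sum(axis=axes).reshape(shape)

x = np.ones((2, 3))
y = np.full((1, 3), 0.5)
x_grad = 2.0 * (x - y)                    # elementwise gradient, shape (2, 3)
gx = reduce_to_shape(x_grad, x.shape)     # (2, 3): no reduction needed
gy = -reduce_to_shape(x_grad, y.shape)    # summed over axis 0 -> (1, 3)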