diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ffb23868623..541c609c872 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2839,6 +2839,10 @@ class Graph(object):
     self._add_control_dependencies = False
     # Cache for OpDef protobufs retrieved via the C API.
     self._op_def_cache = {}
+    # Cache for constant results of `broadcast_gradient_args()`. The keys are
+    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
+    # values are tuples of reduction indices: (rx, ry).
+    self._bcast_grad_args_cache = {}
 
     # TODO(skyewm): fold as much of the above as possible into the C
     # implementation
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 31e5895fd0b..d7a5c02a6a7 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -52,6 +53,79 @@ def _ArgMinGrad(op, grad):
 ops.NotDifferentiable("EuclideanNorm")
 
 
+def SmartBroadcastGradientArgs(x, y, grad):
+  """Optimized version of `broadcast_gradient_args` that caches results.
+
+  This implementation avoids creating `broadcast_gradient_args` ops in the case
+  that the input shapes are fully defined, and provides hints to the calling
+  code that can be used to avoid creating reduction and reshaping ops.
+
+  Args:
+    x: The left input tensor to a broadcasting binary op.
+    y: The right input tensor to a broadcasting binary op.
+    grad: The incoming gradient tensor for a broadcasting binary op.
+
+  Returns:
+    A pair of tuples, containing:
+      * A 3-tuple of broadcast information for x, containing:
+        * The shape of x (as a tuple or Tensor).
+        * The reduction indices for x (as a tuple or Tensor).
+        * A boolean, which if True, indicates that x's shape differs from grad's
+          shape (and so x's gradient must be reduced and/or reshaped).
+      * A 3-tuple of broadcast information for y, containing the respective
+        details for y.
+  """
+  # NOTE: It may be productive to apply these optimizations in the eager case
+  # as well.
+  if context.executing_eagerly() or not (
+      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
+      and isinstance(grad, ops.Tensor)):
+    sx = array_ops.shape(x)
+    sy = array_ops.shape(y)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  # pylint: disable=protected-access
+  x_shape_tuple = x._shape_tuple()
+  y_shape_tuple = y._shape_tuple()
+  grad_shape_tuple = grad._shape_tuple()
+  # pylint: enable=protected-access
+
+  if (x_shape_tuple is None or None in x_shape_tuple or
+      y_shape_tuple is None or None in y_shape_tuple):
+    sx = array_ops.shape_internal(x, optimize=False)
+    sy = array_ops.shape_internal(y, optimize=False)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  x_needs_reduction = x_shape_tuple != grad_shape_tuple
+  y_needs_reduction = y_shape_tuple != grad_shape_tuple
+
+  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
+  # `grad.graph`, because these may be eager tensors.
+  g = ops.get_default_graph()
+
+  try:
+    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
+    return (x_shape_tuple, rx, x_needs_reduction), (
+        y_shape_tuple, ry, y_needs_reduction)
+  except KeyError:
+    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
+    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
+    # `TF_TryEvaluateConstant()`.
+    rx_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        rx.graph._c_graph, rx._as_tf_output()))  # pylint: disable=protected-access
+    assert rx_value is not None
+    ry_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        ry.graph._c_graph, ry._as_tf_output()))  # pylint: disable=protected-access
+    assert ry_value is not None
+    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
+        rx_value, ry_value)
+
+    return (x_shape_tuple, rx_value, x_needs_reduction), (
+        y_shape_tuple, ry_value, y_needs_reduction)
+
+
 _empty_tuple = ()
 
 
@@ -1000,55 +1074,96 @@ def _AddGrad(op, grad):
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   if skip_input_indices is not None and 0 in skip_input_indices:
     gx = None
+  elif not must_reduce_x:
+    gx = grad
   else:
     gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
   if skip_input_indices is not None and 1 in skip_input_indices:
     gy = None
+  elif not must_reduce_y:
+    gy = grad
   else:
     gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
   return (gx, gy)
 
-
 @ops.RegisterGradient("Sub")
 def _SubGrad(op, grad):
   """Gradient for Sub."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return grad, None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, -grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
-  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
-          array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = grad
+  else:
+    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = -grad
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("Mul")
 def _MulGrad(op, grad):
   """The gradient of scalar multiplication."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return gen_math_ops.mul(grad, math_ops.conj(y)), None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32)):
     return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = gen_math_ops.mul(grad, y)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = gen_math_ops.mul(x, grad)
+  else:
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("MulNoNan")
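
For intuition about why the cached (rx, ry) values are reusable: for fully defined shapes, broadcast_gradient_args only reports the output axes along which each input was broadcast, and the gradient of a broadcasting binary op is recovered by summing the incoming gradient over those axes and reshaping to the input shape. Below is a rough pure-NumPy sketch of that behaviour; the helper name broadcast_grad_reduce_indices is illustrative, not a TensorFlow API, and it ignores edge cases such as dimensions that are 1 in both inputs.

    import numpy as np

    def broadcast_grad_reduce_indices(x_shape, y_shape):
      # Align both shapes to the broadcast rank by left-padding with 1s, then
      # collect the output axes along which each input was broadcast. These are
      # the axes the upstream gradient must be summed over.
      rank = max(len(x_shape), len(y_shape))
      xs = (1,) * (rank - len(x_shape)) + tuple(x_shape)
      ys = (1,) * (rank - len(y_shape)) + tuple(y_shape)
      rx = tuple(i for i in range(rank) if xs[i] == 1 and ys[i] != 1)
      ry = tuple(i for i in range(rank) if ys[i] == 1 and xs[i] != 1)
      return rx, ry

    x = np.ones((3, 1))
    y = np.ones((1, 4))
    grad = np.ones((3, 4))  # upstream gradient of z = x + y
    rx, ry = broadcast_grad_reduce_indices(x.shape, y.shape)  # (1,), (0,)
    gx = grad.sum(axis=rx).reshape(x.shape)  # reduce_sum + reshape, as in _AddGrad
    gy = grad.sum(axis=ry).reshape(y.shape)
    assert gx.shape == x.shape and gy.shape == y.shape

When the input shape already equals the gradient shape, no axis needs reducing, which is exactly the case the must_reduce_x / must_reduce_y hints let the gradient functions skip.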
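The graph-level cache added in ops.py works because fully defined shapes are plain tuples of ints, which are hashable and can key a dict; the EAFP try/except KeyError lookup in SmartBroadcastGradientArgs then makes repeated gradients over the same shape pair essentially free. A minimal sketch of the same pattern in isolation (names are illustrative, and it reuses broadcast_grad_reduce_indices from the sketch above in place of constant-folding the real op):

    _bcast_grad_args_cache = {}  # stands in for Graph._bcast_grad_args_cache

    def cached_reduce_indices(x_shape, y_shape):
      key = (tuple(x_shape), tuple(y_shape))
      try:
        return _bcast_grad_args_cache[key]  # hit: no new ops, no recomputation
      except KeyError:
        result = broadcast_grad_reduce_indices(x_shape, y_shape)
        _bcast_grad_args_cache[key] = result
        return result

In the patch itself, the rx/ry tensors are additionally constant-folded via TF_TryEvaluateConstant_wrapper before being stored, so the cached values are Python tuples rather than graph tensors.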
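The same reduction indices serve every broadcasting binary gradient for a given shape pair, which is why _AddGrad, _SubGrad and _MulGrad can all consume one cached result; only the per-op factor applied to grad differs (plus conj handling for complex types in _MulGrad). Continuing the NumPy sketch above for the Mul case, mirroring the reduce_sum/reshape branches in _MulGrad:

    # d(x * y)/dx = grad * y, summed over the axes where x was broadcast;
    # d(x * y)/dy = x * grad, summed over the axes where y was broadcast.
    gx_mul = (grad * y).sum(axis=rx).reshape(x.shape)
    gy_mul = (x * grad).sum(axis=ry).reshape(y.shape)
    assert gx_mul.shape == x.shape and gy_mul.shape == y.shape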