Optimize gradient calculation for +, *, and / in graph mode.

This change adds a graph-level cache for the results of `broadcast_gradient_args()`
when the inputs have statically known shapes. Using this cache, the gradient
functions can avoid generating unnecessary shape and `broadcast_gradient_args`
ops, and can skip the reduction and reshape ops entirely when an input's shape
already matches the gradient's shape. This shrinks the gradient graph and
improves startup time.
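For intuition, here is a minimal self-contained sketch of the caching idea in plain Python. The helper names are hypothetical (the real change lives in `Graph.__init__` and `math_grad.py` below), and the index rule is a simplification of what the `BroadcastGradientArgs` kernel computes:

import itertools

def reduction_indices(x_shape, y_shape):
  # Axes of the broadcast output over which each input's gradient must be
  # summed: those where that input's (left-padded) dimension is 1 and the
  # other input's dimension is not.
  n = max(len(x_shape), len(y_shape))
  xs = (1,) * (n - len(x_shape)) + tuple(x_shape)
  ys = (1,) * (n - len(y_shape)) + tuple(y_shape)
  rx = tuple(i for i in range(n) if xs[i] == 1 and ys[i] != 1)
  ry = tuple(i for i in range(n) if ys[i] == 1 and xs[i] != 1)
  return rx, ry

_bcast_grad_args_cache = {}  # (x_shape, y_shape) -> (rx, ry)

def cached_reduction_indices(x_shape, y_shape):
  # For fully defined shapes the result is a pure function of the two
  # shape tuples, so a dict can memoize it; repeat calls are lookups.
  key = (x_shape, y_shape)
  if key not in _bcast_grad_args_cache:
    _bcast_grad_args_cache[key] = reduction_indices(*key)
  return _bcast_grad_args_cache[key]

# Broadcasting (3, 1) op (1, 4): x's gradient is summed over axis 1 and
# y's gradient over axis 0; the second call is a pure dict lookup.
assert cached_reduction_indices((3, 1), (1, 4)) == ((1,), (0,))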

PiperOrigin-RevId: 260239456
Author: Derek Murray, 2019-07-26 16:54:49 -07:00 (committed by TensorFlower Gardener)
Commit: 04c51715d5 (parent: 6ad7d2ac03)
2 changed files with 137 additions and 18 deletions

--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py

@@ -2839,6 +2839,10 @@ class Graph(object):
     self._add_control_dependencies = False
     # Cache for OpDef protobufs retrieved via the C API.
     self._op_def_cache = {}
+    # Cache for constant results of `broadcast_gradient_args()`. The keys are
+    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
+    # values are tuples of reduction indices: (rx, ry).
+    self._bcast_grad_args_cache = {}
     # TODO(skyewm): fold as much of the above as possible into the C
     # implementation
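As an illustration (not part of the diff), after gradients have been built for a single broadcasting op, this cache might hold the following, assuming hypothetical inputs with x: (3, 1) and y: (1, 4):

# Hypothetical cache state: axis 1 broadcast x and axis 0 broadcast y,
# so those axes are the respective reduction indices.
# g._bcast_grad_args_cache == {((3, 1), (1, 4)): ((1,), (0,))}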

--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py

@@ -19,6 +19,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.compat import compat
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -52,6 +53,79 @@ def _ArgMinGrad(op, grad):
 ops.NotDifferentiable("EuclideanNorm")
+
+
+def SmartBroadcastGradientArgs(x, y, grad):
+  """Optimized version of `broadcast_gradient_args` that caches results.
+
+  This implementation avoids creating `broadcast_gradient_args` ops in the case
+  that the input shapes are fully defined, and provides hints to the calling
+  code that can be used to avoid creating reduction and reshaping ops.
+
+  Args:
+    x: The left input tensor to a broadcasting binary op.
+    y: The right input tensor to a broadcasting binary op.
+    grad: The incoming gradient tensor for a broadcasting binary op.
+
+  Returns:
+    A pair of tuples, containing:
+      * A 3-tuple of broadcast information for x, containing:
+        * The shape of x (as a tuple or Tensor).
+        * The reduction indices for x (as a tuple or Tensor).
+        * A boolean, which if True, indicates that x's shape differs from grad's
+          shape (and so x's gradient must be reduced and/or reshaped).
+      * A 3-tuple of broadcast information for y, containing the respective
+        details for y.
+  """
+  # NOTE: It may be productive to apply these optimizations in the eager case
+  # as well.
+  if context.executing_eagerly() or not (
+      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
+      and isinstance(grad, ops.Tensor)):
+    sx = array_ops.shape(x)
+    sy = array_ops.shape(y)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  # pylint: disable=protected-access
+  x_shape_tuple = x._shape_tuple()
+  y_shape_tuple = y._shape_tuple()
+  grad_shape_tuple = grad._shape_tuple()
+  # pylint: enable=protected-access
+
+  if (x_shape_tuple is None or None in x_shape_tuple or
+      y_shape_tuple is None or None in y_shape_tuple):
+    sx = array_ops.shape_internal(x, optimize=False)
+    sy = array_ops.shape_internal(y, optimize=False)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  x_needs_reduction = x_shape_tuple != grad_shape_tuple
+  y_needs_reduction = y_shape_tuple != grad_shape_tuple
+
+  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
+  # `grad.graph`, because these may be eager tensors.
+  g = ops.get_default_graph()
+
+  try:
+    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
+    return (x_shape_tuple, rx, x_needs_reduction), (
+        y_shape_tuple, ry, y_needs_reduction)
+  except KeyError:
+    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
+    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
+    # `TF_TryEvaluateConstant()`.
+    rx_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        rx.graph._c_graph, rx._as_tf_output()))  # pylint: disable=protected-access
+    assert rx_value is not None
+    ry_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        ry.graph._c_graph, ry._as_tf_output()))  # pylint: disable=protected-access
+    assert ry_value is not None
+    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
+        rx_value, ry_value)
+
+    return (x_shape_tuple, rx_value, x_needs_reduction), (
+        y_shape_tuple, ry_value, y_needs_reduction)
+
+
 _empty_tuple = ()
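A hedged usage sketch of the fast path, assuming a TF 1.x-era build that includes this change: with fully defined shapes, the returned shapes and reduction indices are plain Python tuples (no `BroadcastGradientArgs` op feeds the gradient computation), and the booleans report whether reduction is needed.

import tensorflow as tf
from tensorflow.python.ops import math_grad

with tf.Graph().as_default():
  x = tf.zeros([3, 1])
  y = tf.zeros([1, 4])
  grad = tf.zeros([3, 4])  # stands in for the upstream gradient
  (sx, rx, red_x), (sy, ry, red_y) = math_grad.SmartBroadcastGradientArgs(
      x, y, grad)
  print(sx, rx, red_x)  # (3, 1) (1,) True
  print(sy, ry, red_y)  # (1, 4) (0,) True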
@@ -1000,55 +1074,96 @@ def _AddGrad(op, grad):
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   if skip_input_indices is not None and 0 in skip_input_indices:
     gx = None
+  elif not must_reduce_x:
+    gx = grad
   else:
     gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
   if skip_input_indices is not None and 1 in skip_input_indices:
     gy = None
+  elif not must_reduce_y:
+    gy = grad
   else:
     gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
   return (gx, gy)
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
"""Gradient for Sub."""
x = op.inputs[0]
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
return grad, None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return grad, -grad
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif not must_reduce_x:
gx = grad
else:
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif not must_reduce_y:
gy = -grad
else:
gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
return (gx, gy)
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
"""The gradient of scalar multiplication."""
x = op.inputs[0]
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
return gen_math_ops.mul(grad, math_ops.conj(y)), None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad) and
grad.dtype in (dtypes.int32, dtypes.float32)):
return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy))
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif not must_reduce_x:
gx = gen_math_ops.mul(grad, y)
else:
gx = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif not must_reduce_y:
gy = gen_math_ops.mul(x, grad)
else:
gy = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
return (gx, gy)
@ops.RegisterGradient("MulNoNan")