Optimize gradient calculation for +, *, and / in graph mode.

This change adds a graph-level cache for the results of `broadcast_gradient_args()` when the inputs have statically known shapes. Using this cache, the gradient functions can avoid generating unnecessary ops, which shrinks the graph and improves startup time.

PiperOrigin-RevId: 260239456
commit 04c51715d5
parent 6ad7d2ac03
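For background (not part of the change itself): `broadcast_gradient_args` is the op that the gradient functions of broadcasting binary ops such as Add and Mul emit to work out which axes each input's gradient must be summed over. A small illustration with made-up shapes, assuming a TensorFlow build that exposes `tf.raw_ops`:

# Illustration only: what BroadcastGradientArgs computes for two static shapes.
import tensorflow as tf

rx, ry = tf.raw_ops.BroadcastGradientArgs(s0=[32, 10], s1=[10])
# rx == []   -> the gradient w.r.t. the [32, 10] input needs no reduction.
# ry == [0]  -> the gradient w.r.t. the [10] input is reduce_sum(grad, [0]).

When both shapes are statically known, these indices are themselves constants, which is what the cache added below exploits.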
@@ -2839,6 +2839,10 @@ class Graph(object):
     self._add_control_dependencies = False
     # Cache for OpDef protobufs retrieved via the C API.
     self._op_def_cache = {}
+    # Cache for constant results of `broadcast_gradient_args()`. The keys are
+    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
+    # values are tuples of reduction indices: (rx, ry).
+    self._bcast_grad_args_cache = {}
 
     # TODO(skyewm): fold as much of the above as possible into the C
     # implementation
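As a purely hypothetical illustration of the comment above (shapes invented for the example), an entry in this cache maps a pair of fully-defined input shapes to the constant reduction indices computed for them:

# Sketch only; the real dict is the Graph attribute added above.
bcast_grad_args_cache = {}
bcast_grad_args_cache[((32, 10), (10,))] = ((), (0,))  # (rx, ry)
assert bcast_grad_args_cache[((32, 10), (10,))] == ((), (0,))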
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -52,6 +53,79 @@ def _ArgMinGrad(op, grad):
 ops.NotDifferentiable("EuclideanNorm")
 
 
+def SmartBroadcastGradientArgs(x, y, grad):
+  """Optimized version of `broadcast_gradient_args` that caches results.
+
+  This implementation avoids creating `broadcast_gradient_args` ops in the case
+  that the input shapes are fully defined, and provides hints to the calling
+  code that can be used to avoid creating reduction and reshaping ops.
+
+  Args:
+    x: The left input tensor to a broadcasting binary op.
+    y: The right input tensor to a broadcasting binary op.
+    grad: The incoming gradient tensor for a broadcasting binary op.
+
+  Returns:
+    A pair of tuples, containing:
+      * A 3-tuple of broadcast information for x, containing:
+        * The shape of x (as a tuple or Tensor).
+        * The reduction indices for x (as a tuple or Tensor).
+        * A boolean, which if True, indicates that x's shape differs from
+          grad's shape (and so x's gradient must be reduced and/or reshaped).
+      * A 3-tuple of broadcast information for y, containing the respective
+        details for y.
+  """
+  # NOTE: It may be productive to apply these optimizations in the eager case
+  # as well.
+  if context.executing_eagerly() or not (
+      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
+      and isinstance(grad, ops.Tensor)):
+    sx = array_ops.shape(x)
+    sy = array_ops.shape(y)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  # pylint: disable=protected-access
+  x_shape_tuple = x._shape_tuple()
+  y_shape_tuple = y._shape_tuple()
+  grad_shape_tuple = grad._shape_tuple()
+  # pylint: enable=protected-access
+
+  if (x_shape_tuple is None or None in x_shape_tuple or
+      y_shape_tuple is None or None in y_shape_tuple):
+    sx = array_ops.shape_internal(x, optimize=False)
+    sy = array_ops.shape_internal(y, optimize=False)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  x_needs_reduction = x_shape_tuple != grad_shape_tuple
+  y_needs_reduction = y_shape_tuple != grad_shape_tuple
+
+  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
+  # `grad.graph`, because these may be eager tensors.
+  g = ops.get_default_graph()
+
+  try:
+    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
+    return (x_shape_tuple, rx, x_needs_reduction), (
+        y_shape_tuple, ry, y_needs_reduction)
+  except KeyError:
+    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
+    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
+    # `TF_TryEvaluateConstant()`.
+    rx_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        rx.graph._c_graph, rx._as_tf_output()))  # pylint: disable=protected-access
+    assert rx_value is not None
+    ry_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        ry.graph._c_graph, ry._as_tf_output()))  # pylint: disable=protected-access
+    assert ry_value is not None
+    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
+        rx_value, ry_value)
+
+    return (x_shape_tuple, rx_value, x_needs_reduction), (
+        y_shape_tuple, ry_value, y_needs_reduction)
+
+
 _empty_tuple = ()
 
 
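A hypothetical walk-through of the fully-static-shape path above, with invented shapes; `tensorflow.python.ops.math_grad` is an internal module, so this only illustrates what the returned triples look like:

import tensorflow as tf
from tensorflow.python.ops import math_grad

with tf.Graph().as_default():  # the cached path is only taken in graph mode
  x = tf.zeros([32, 10])
  y = tf.zeros([10])
  grad = tf.ones([32, 10])
  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      math_grad.SmartBroadcastGradientArgs(x, y, grad))
  # sx == (32, 10), rx == (),   must_reduce_x is False -> gx can simply be grad.
  # sy == (10,),    ry == (0,), must_reduce_y is True  -> gy becomes
  #     reshape(reduce_sum(grad, (0,)), sy), as the gradient functions below do.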
@@ -1000,55 +1074,96 @@ def _AddGrad(op, grad):
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   if skip_input_indices is not None and 0 in skip_input_indices:
     gx = None
+  elif not must_reduce_x:
+    gx = grad
   else:
     gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
   if skip_input_indices is not None and 1 in skip_input_indices:
     gy = None
+  elif not must_reduce_y:
+    gy = grad
   else:
     gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
   return (gx, gy)
 
 
 @ops.RegisterGradient("Sub")
 def _SubGrad(op, grad):
   """Gradient for Sub."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return grad, None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, -grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
-  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
-          array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = grad
+  else:
+    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = -grad
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("Mul")
 def _MulGrad(op, grad):
   """The gradient of scalar multiplication."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return gen_math_ops.mul(grad, math_ops.conj(y)), None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad) and
       grad.dtype in (dtypes.int32, dtypes.float32)):
     return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = gen_math_ops.mul(grad, y)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = gen_math_ops.mul(x, grad)
+  else:
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("MulNoNan")