Optimize gradient calculation for `+`, `-`, `*`, and `/` in graph mode.
This change adds a graph-level cache for the results of `broadcast_gradient_args()` when the inputs have statically known shapes. Using this cache, the gradient functions can avoid generating unnecessary ops, which shrinks the graph and improves startup time. PiperOrigin-RevId: 260239456
This commit is contained in:
parent
6ad7d2ac03
commit
04c51715d5
@ -2839,6 +2839,10 @@ class Graph(object):
|
|||||||
self._add_control_dependencies = False
|
self._add_control_dependencies = False
|
||||||
# Cache for OpDef protobufs retrieved via the C API.
|
# Cache for OpDef protobufs retrieved via the C API.
|
||||||
self._op_def_cache = {}
|
self._op_def_cache = {}
|
||||||
|
# Cache for constant results of `broadcast_gradient_args()`. The keys are
|
||||||
|
# tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
|
||||||
|
# values are tuples of reduction indices: (rx, ry).
|
||||||
|
self._bcast_grad_args_cache = {}
|
||||||
|
|
||||||
# TODO(skyewm): fold as much of the above as possible into the C
|
# TODO(skyewm): fold as much of the above as possible into the C
|
||||||
# implementation
|
# implementation
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tensorflow as c_api
|
||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -52,6 +53,79 @@ def _ArgMinGrad(op, grad):
|
|||||||
ops.NotDifferentiable("EuclideanNorm")
|
ops.NotDifferentiable("EuclideanNorm")
|
||||||
|
|
||||||
|
|
||||||
|
def SmartBroadcastGradientArgs(x, y, grad):
  """Optimized version of `broadcast_gradient_args` that caches results.

  This implementation avoids creating `broadcast_gradient_args` ops in the case
  that the input shapes are fully defined, and provides hints to the calling
  code that can be used to avoid creating reduction and reshaping ops.

  Args:
    x: The left input tensor to a broadcasting binary op.
    y: The right input tensor to a broadcasting binary op.
    grad: The incoming gradient tensor for a broadcasting binary op.

  Returns:
    A pair of tuples, containing:
      * A 3-tuple of broadcast information for x, containing:
        * The shape of x (as a tuple or Tensor).
        * The reduction indices for x (as a tuple or Tensor).
        * A boolean, which if True, indicates that x's shape differs from grad's
          shape (and so x's gradient must be reduced and/or reshaped).
      * A 3-tuple of broadcast information for y, containing the respective
        details for y.
  """
  # NOTE: It may be productive to apply these optimizations in the eager case
  # as well.
  if context.executing_eagerly() or not (
      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      and isinstance(grad, ops.Tensor)):
    # Fall back to the dynamic-shape path: emit shape and
    # broadcast_gradient_args ops, and conservatively report that both
    # gradients need reduction.
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  # pylint: disable=protected-access
  x_shape_tuple = x._shape_tuple()
  y_shape_tuple = y._shape_tuple()
  grad_shape_tuple = grad._shape_tuple()
  # pylint: enable=protected-access

  if (x_shape_tuple is None or None in x_shape_tuple or
      y_shape_tuple is None or None in y_shape_tuple):
    # At least one input shape is not fully defined; the reduction indices
    # must be computed at runtime. `optimize=False` prevents the shape op
    # from being constant-folded away prematurely.
    sx = array_ops.shape_internal(x, optimize=False)
    sy = array_ops.shape_internal(y, optimize=False)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  x_needs_reduction = x_shape_tuple != grad_shape_tuple
  y_needs_reduction = y_shape_tuple != grad_shape_tuple

  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
  # `grad.graph`, because these may be eager tensors.
  g = ops.get_default_graph()

  try:
    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
    return (x_shape_tuple, rx, x_needs_reduction), (
        y_shape_tuple, ry, y_needs_reduction)
  except KeyError:
    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
    # `TF_TryEvaluateConstant()`.
    # NOTE: Check the raw result for None *before* converting to a tuple;
    # `tuple(None)` would raise a TypeError first, making a post-conversion
    # assert unreachable.
    rx_raw = c_api.TF_TryEvaluateConstant_wrapper(
        rx.graph._c_graph, rx._as_tf_output())  # pylint: disable=protected-access
    assert rx_raw is not None
    rx_value = tuple(rx_raw)
    ry_raw = c_api.TF_TryEvaluateConstant_wrapper(
        ry.graph._c_graph, ry._as_tf_output())  # pylint: disable=protected-access
    assert ry_raw is not None
    ry_value = tuple(ry_raw)
    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
        rx_value, ry_value)

    return (x_shape_tuple, rx_value, x_needs_reduction), (
        y_shape_tuple, ry_value, y_needs_reduction)
|
||||||
|
|
||||||
|
|
||||||
_empty_tuple = ()
|
_empty_tuple = ()
|
||||||
|
|
||||||
|
|
||||||
@ -1000,55 +1074,96 @@ def _AddGrad(op, grad):
|
|||||||
if (isinstance(grad, ops.Tensor) and
|
if (isinstance(grad, ops.Tensor) and
|
||||||
_ShapesFullySpecifiedAndEqual(x, y, grad)):
|
_ShapesFullySpecifiedAndEqual(x, y, grad)):
|
||||||
return grad, grad
|
return grad, grad
|
||||||
sx = array_ops.shape(x)
|
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
|
||||||
sy = array_ops.shape(y)
|
SmartBroadcastGradientArgs(x, y, grad))
|
||||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
|
||||||
if skip_input_indices is not None and 0 in skip_input_indices:
|
if skip_input_indices is not None and 0 in skip_input_indices:
|
||||||
gx = None
|
gx = None
|
||||||
|
elif not must_reduce_x:
|
||||||
|
gx = grad
|
||||||
else:
|
else:
|
||||||
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
|
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
|
||||||
if skip_input_indices is not None and 1 in skip_input_indices:
|
if skip_input_indices is not None and 1 in skip_input_indices:
|
||||||
gy = None
|
gy = None
|
||||||
|
elif not must_reduce_y:
|
||||||
|
gy = grad
|
||||||
else:
|
else:
|
||||||
gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
|
gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
|
||||||
return (gx, gy)
|
return (gx, gy)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
  """Gradient for Sub."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    # Shapes are statically known and equal: no reduction/reshape needed.
    return grad, -grad
  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = grad
  else:
    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = -grad
  else:
    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
  return (gx, gy)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
  """The gradient of scalar multiplication."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return gen_math_ops.mul(grad, math_ops.conj(y)), None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32)):
    # Fast path: equal fully-defined shapes and a real dtype that needs no
    # conjugation.
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)

  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = gen_math_ops.mul(grad, y)
  else:
    gx = array_ops.reshape(
        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = gen_math_ops.mul(x, grad)
  else:
    gy = array_ops.reshape(
        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
  return (gx, gy)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("MulNoNan")
|
@ops.RegisterGradient("MulNoNan")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user