Merge pull request #31899 from mrry/r1.15
[r1.15 cherrypick] Fix tf.gradients() performance regression
Commit bd96595653
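Context for the diff below: the gradient helpers now receive a precomputed ObjectIdentitySet of the differentiation targets (xs_set) instead of the raw xs list. Previously, _Inputs() rebuilt object_identity.ObjectIdentitySet(xs) on every call, i.e. once per op visited during backprop, so the setup cost grew with both the graph size and the number of xs. The sketch below is illustrative only, not TensorFlow's implementation; IdentitySet, inputs_old, and inputs_new are made-up names used to show the cost pattern of the change.

# Illustrative sketch only -- a stand-in for object_identity.ObjectIdentitySet.
class IdentitySet(object):
  """A set that stores and looks up objects by id() rather than by equality."""

  def __init__(self, items=()):
    self._by_id = {id(x): x for x in items}

  def __contains__(self, item):
    return id(item) in self._by_id


def inputs_old(op_inputs, xs):
  # Cost pattern of the old helper: the identity set is rebuilt on every call,
  # so each visited op pays O(len(xs)) before any membership test happens.
  xs_set = IdentitySet(xs)
  return [t for t in op_inputs if t not in xs_set]


def inputs_new(op_inputs, xs_set):
  # Cost pattern after the fix: the caller builds xs_set once, and every
  # membership test is a constant-time id() lookup.
  return [t for t in op_inputs if t not in xs_set]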
@@ -69,7 +69,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
 
 
 def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
-                  xs):
+                  xs_set):
   """Initialize the pending count for ops between two lists of Operations.
 
   'pending_count[op]' indicates the number of backprop inputs
@@ -83,7 +83,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -113,7 +113,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _NonEagerInputs(op, xs):
+      for inp in _NonEagerInputs(op, xs_set):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -125,7 +125,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _NonEagerInputs(op, xs):
+    for x in _NonEagerInputs(op, xs_set):
       if x.op in between_ops:
         pending_count[x.op] += 1
 
@@ -265,7 +265,7 @@ def _VerifyGeneratedGradients(grads, op):
         "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -281,7 +281,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     The set of operations.
@@ -289,7 +289,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _NonEagerInputs(op, xs):
+    for inp in _NonEagerInputs(op, xs_set):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -369,7 +369,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
   """Raises an error if we backprop through a loop var."""
   # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
   # message.
@@ -383,7 +383,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -425,7 +425,7 @@ def _MaybeCaptured(t):
   return t
 
 
-def _NonEagerInputs(op, xs):
+def _NonEagerInputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
@@ -433,29 +433,28 @@ def _NonEagerInputs(op, xs):
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
 
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _Inputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -464,7 +463,7 @@ def _Inputs(op, xs):
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in tensors:
+      if t not in xs_set:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -546,6 +545,7 @@ def _GradientsHelper(ys,
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
         xs, name="x", as_ref=True)
+    xs_set = object_identity.ObjectIdentitySet(xs)
     grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                              gradient_uid)
 
@@ -562,7 +562,7 @@ def _GradientsHelper(ys,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
 
     # Iterate over the collected ops.
     #
@@ -596,7 +596,7 @@ def _GradientsHelper(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
-    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
@@ -649,7 +649,7 @@ def _GradientsHelper(ys,
             op._control_flow_context.IsWhileContext() and
             op._control_flow_context ==
             ops.get_default_graph()._get_control_flow_context()):
-          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
         # pylint: enable=protected-access
 
         if (grad_fn or is_func_call) and has_out_grads:
@@ -696,10 +696,10 @@ def _GradientsHelper(ys,
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
+          in_grads = [None] * len(_Inputs(op, xs_set))
         # Note: we don't filter out eager inputs here because the inputs need to
         # line up with in_grads.
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -719,7 +719,7 @@ def _GradientsHelper(ys,
 
       # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                    xs)
+                                    xs_set)
 
   if loop_state:
     loop_state.PostProcessing()
@@ -739,9 +739,9 @@ def _HasAnyNotNoneGrads(grads, op):
 
 
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                  xs):
+                                  xs_set):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _NonEagerInputs(op, xs):
+  for x in _NonEagerInputs(op, xs_set):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
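Taken together: _GradientsHelper now builds xs_set = object_identity.ObjectIdentitySet(xs) once, right after converting xs, and every traversal helper (_PendingCount, _StopOps, _NonEagerInputs, _Inputs, _RaiseNoGradWrtInitialLoopValError, _UpdatePendingAndEnqueueReady) takes that set instead of the list. A rough idea of the kind of graph-mode call whose cost regressed is sketched below; the variable count is arbitrary and this is not a benchmark from the PR.

import tensorflow as tf  # TF 1.x graph mode, where tf.gradients() applies

# Hypothetical illustration only: with many differentiation targets, the
# per-op work inside the gradient helpers dominates. With this fix, checks
# against xs are O(1) identity-set lookups instead of per-op set rebuilds.
xs = [tf.Variable(float(i)) for i in range(2000)]  # arbitrary size
loss = tf.add_n([x * x for x in xs])
grads = tf.gradients(loss, xs)  # builds the backprop graph over all 2000 xs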