Merge pull request #31899 from mrry/r1.15
[r1.15 cherrypick] Fix tf.gradients() performance regression
commit bd96595653
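Summary of the change: `_Inputs()` previously rebuilt `object_identity.ObjectIdentitySet(xs)` on every call; this commit hoists that construction into `_GradientsHelper()` as `xs_set` and threads the set through `_PendingCount()`, `_StopOps()`, `_RaiseNoGradWrtInitialLoopValError()`, `_NonEagerInputs()`, `_Inputs()` and `_UpdatePendingAndEnqueueReady()`, so each visited op pays only O(1) identity-based membership tests instead of re-constructing the set. The sketch below illustrates the pattern only; `IdentitySet`, `visit_op_before` and `visit_op_after` are hypothetical names, with a plain `id()`-keyed set standing in for TensorFlow's `ObjectIdentitySet`. The patched hunks follow.

# Sketch only: the hoisting pattern applied in this commit, not TensorFlow's
# actual code. IdentitySet is a hypothetical stand-in for
# tensorflow.python.util.object_identity.ObjectIdentitySet, which compares
# members by object identity rather than by Tensor.__eq__.
class IdentitySet(object):

  def __init__(self, items):
    self._ids = {id(x) for x in items}

  def __contains__(self, item):
    return id(item) in self._ids


def visit_op_before(op_inputs, xs):
  # Pre-fix shape: every visited op rebuilt the identity set, costing
  # O(len(xs)) per call on top of the O(1) membership tests.
  xs_set = IdentitySet(xs)
  return [t for t in op_inputs if t in xs_set]


def visit_op_after(op_inputs, xs_set):
  # Post-fix shape: the caller builds xs_set once and passes it down, so each
  # visit pays only the O(1) membership tests.
  return [t for t in op_inputs if t in xs_set]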
@@ -69,7 +69,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
 
 
 def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
-                  xs):
+                  xs_set):
   """Initialize the pending count for ops between two lists of Operations.
 
   'pending_count[op]' indicates the number of backprop inputs
@@ -83,7 +83,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -113,7 +113,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _NonEagerInputs(op, xs):
+      for inp in _NonEagerInputs(op, xs_set):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -125,7 +125,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _NonEagerInputs(op, xs):
+    for x in _NonEagerInputs(op, xs_set):
       if x.op in between_ops:
         pending_count[x.op] += 1
 
@@ -265,7 +265,7 @@ def _VerifyGeneratedGradients(grads, op):
                      "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -281,7 +281,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
     pending_count: mapping from operation to number of backprop inputs.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     The set of operations.
@@ -289,7 +289,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _NonEagerInputs(op, xs):
+    for inp in _NonEagerInputs(op, xs_set):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -369,7 +369,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
   """Raises an error if we backprop through a loop var."""
   # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
   # message.
@@ -383,7 +383,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -425,7 +425,7 @@ def _MaybeCaptured(t):
   return t
 
 
-def _NonEagerInputs(op, xs):
+def _NonEagerInputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
@@ -433,29 +433,28 @@ def _NonEagerInputs(op, xs):
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
 
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _Inputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -464,7 +463,7 @@ def _Inputs(op, xs):
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in tensors:
+      if t not in xs_set:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -546,6 +545,7 @@ def _GradientsHelper(ys,
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
         xs, name="x", as_ref=True)
+    xs_set = object_identity.ObjectIdentitySet(xs)
     grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                              gradient_uid)
 
@@ -562,7 +562,7 @@ def _GradientsHelper(ys,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
 
     # Iterate over the collected ops.
     #
@@ -596,7 +596,7 @@ def _GradientsHelper(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
-    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
@@ -649,7 +649,7 @@ def _GradientsHelper(ys,
             op._control_flow_context.IsWhileContext() and
             op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
-          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
         # pylint: enable=protected-access
 
         if (grad_fn or is_func_call) and has_out_grads:
@@ -696,10 +696,10 @@ def _GradientsHelper(ys,
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
+          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
@@ -719,7 +719,7 @@ def _GradientsHelper(ys,
 
       # Update pending count for the inputs of op and enqueue ready ops.
       _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                    xs)
+                                    xs_set)
 
     if loop_state:
       loop_state.PostProcessing()
@@ -739,9 +739,9 @@ def _HasAnyNotNoneGrads(grads, op):
 
 
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                  xs):
+                                  xs_set):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _NonEagerInputs(op, xs):
+  for x in _NonEagerInputs(op, xs_set):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
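For context (not part of this commit): the regression shows up as graph-construction time in tf.gradients(), which scales with the number of ops the backprop traversal visits times the per-op cost of the helpers patched above. A hypothetical timing harness along these lines, assuming TF 1.x graph mode (or tf.compat.v1 with eager execution disabled), exercises that path:

# Hypothetical repro harness, not part of the commit. Assumes TF 1.x graph
# mode; under TF 2.x, use tf.compat.v1 and disable eager execution first.
import time
import tensorflow as tf

xs = [tf.constant(float(i)) for i in range(500)]  # many tensors to differentiate w.r.t.
y = tf.add_n(xs)
for _ in range(200):  # a chain of ops so the backprop traversal visits many nodes
  y = y * 1.0001 + 1.0

start = time.time()
grads = tf.gradients(y, xs)  # what regressed is the time to build this gradient graph
print("tf.gradients() graph build took %.3fs for %d gradients"
      % (time.time() - start, len(grads)))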