Merge pull request #31899 from mrry/r1.15

[r1.15 cherrypick] Fix tf.gradients() performance regression
Alexandre Passos 2019-08-22 10:40:40 -07:00 committed by GitHub
commit bd96595653

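Summary of the change, as read from the diff below: `_Inputs()` used to rebuild an `object_identity.ObjectIdentitySet` from the `xs` list on every call, i.e. once per visited op; `_GradientsHelper` now builds `xs_set` once and threads it through the helper functions (`_PendingCount`, `_StopOps`, `_NonEagerInputs`, `_Inputs`, `_UpdatePendingAndEnqueueReady`, and the loop-value error path), so the set construction no longer scales with the number of ops visited during backprop. A minimal sketch of the same hoisting pattern follows; the names in it (`IdentitySet`, `inputs_slow`, `inputs_fast`, `walk_graph`) are illustrative stand-ins, not TensorFlow APIs.

# Illustrative sketch only (hypothetical names, not the TF API): hoist an
# identity-based set out of a per-op helper so it is built once per gradients
# call instead of once per visited op.

class IdentitySet(object):
  """Simplified stand-in for object_identity.ObjectIdentitySet.

  Membership is tested by object identity, which is what you want for
  Tensor-like objects whose __eq__/__hash__ may be overloaded.
  """

  def __init__(self, items=()):
    self._ids = {id(x) for x in items}

  def __contains__(self, item):
    return id(item) in self._ids


def inputs_slow(op_inputs, xs):
  # Old shape of the code: the set is rebuilt on every call, O(len(xs)) each time.
  xs_set = IdentitySet(xs)
  return [t for t in op_inputs if t not in xs_set]


def inputs_fast(op_inputs, xs_set):
  # New shape of the code: the caller builds the set once and passes it in.
  return [t for t in op_inputs if t not in xs_set]


def walk_graph(all_op_inputs, xs):
  # Analogue of _GradientsHelper: build xs_set once, reuse it for every op.
  xs_set = IdentitySet(xs)
  return [inputs_fast(op_inputs, xs_set) for op_inputs in all_op_inputs]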

@@ -69,7 +69,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
 def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
-                  xs):
+                  xs_set):
   """Initialize the pending count for ops between two lists of Operations.
 
   'pending_count[op]' indicates the number of backprop inputs
@@ -83,7 +83,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -113,7 +113,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _NonEagerInputs(op, xs):
+      for inp in _NonEagerInputs(op, xs_set):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -125,7 +125,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _NonEagerInputs(op, xs):
+    for x in _NonEagerInputs(op, xs_set):
       if x.op in between_ops:
         pending_count[x.op] += 1
@@ -265,7 +265,7 @@ def _VerifyGeneratedGradients(grads, op):
                      "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -281,7 +281,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
     pending_count: mapping from operation to number of backprop inputs.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     The set of operations.
@@ -289,7 +289,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _NonEagerInputs(op, xs):
+    for inp in _NonEagerInputs(op, xs_set):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -369,7 +369,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
   """Raises an error if we backprop through a loop var."""
   # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
   # message.
@@ -383,7 +383,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -425,7 +425,7 @@ def _MaybeCaptured(t):
   return t
 
 
-def _NonEagerInputs(op, xs):
+def _NonEagerInputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
@@ -433,29 +433,28 @@ def _NonEagerInputs(op, xs):
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
 
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _Inputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -464,7 +463,7 @@ def _Inputs(op, xs):
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in tensors:
+      if t not in xs_set:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -546,6 +545,7 @@ def _GradientsHelper(ys,
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
         xs, name="x", as_ref=True)
+    xs_set = object_identity.ObjectIdentitySet(xs)
     grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                              gradient_uid)
@@ -562,7 +562,7 @@ def _GradientsHelper(ys,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
 
     # Iterate over the collected ops.
     #
@@ -596,7 +596,7 @@ def _GradientsHelper(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
-    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
@@ -649,7 +649,7 @@ def _GradientsHelper(ys,
               op._control_flow_context.IsWhileContext() and
               op._control_flow_context ==
               ops.get_default_graph()._get_control_flow_context()):
-            _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+            _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
           # pylint: enable=protected-access
 
           if (grad_fn or is_func_call) and has_out_grads:
@@ -696,10 +696,10 @@ def _GradientsHelper(ys,
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
+          in_grads = [None] * len(_Inputs(op, xs_set))
         # Note: we don't filter out eager inputs here because the inputs need to
         # line up with in_grads.
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -719,7 +719,7 @@ def _GradientsHelper(ys,
       # Update pending count for the inputs of op and enqueue ready ops.
       _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                    xs)
+                                    xs_set)
 
   if loop_state:
     loop_state.PostProcessing()
@@ -739,9 +739,9 @@ def _HasAnyNotNoneGrads(grads, op):
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                  xs):
+                                  xs_set):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _NonEagerInputs(op, xs):
+  for x in _NonEagerInputs(op, xs_set):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
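
A rough way to see why the per-call construction mattered (an illustration with plain Python objects and `set`, not the original regression benchmark): with N visited ops and M tensors in `xs`, rebuilding the set for every op costs O(N * M) set-construction work, while building it once costs O(N + M).

# Hypothetical micro-benchmark sketch; names and sizes are illustrative only.
import timeit

xs = list(range(5000))                      # stand-ins for the xs tensors
per_op_inputs = [[i] for i in range(2000)]  # stand-ins for each op's inputs

def per_call():
  for inputs in per_op_inputs:
    xs_set = set(xs)                        # rebuilt for every op (old behaviour)
    _ = [t for t in inputs if t not in xs_set]

def hoisted():
  xs_set = set(xs)                          # built once (new behaviour)
  for inputs in per_op_inputs:
    _ = [t for t in inputs if t not in xs_set]

print("per-call set construction:", timeit.timeit(per_call, number=3))
print("hoisted set construction: ", timeit.timeit(hoisted, number=3))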