From bddb975d1cb0ec172ed353fcbe8a49c1fea2db8e Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Thu, 22 Aug 2019 07:05:23 -0700
Subject: [PATCH] In _GradientsHelper() compute the ObjectIdentitySet(xs) once and reuse it.

This avoids a potentially quadratic execution time in building the gradient
graph, because we were previously creating the set multiple times for each op
in the graph.

PiperOrigin-RevId: 264826531
---
 tensorflow/python/ops/gradients_util.py | 48 ++++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 0a0ea32c205..53520ec1075 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -69,7 +69,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
 
 
 def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
-                  xs):
+                  xs_set):
   """Initialize the pending count for ops between two lists of Operations.
 
   'pending_count[op]' indicates the number of backprop inputs
@@ -83,7 +83,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -113,7 +113,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _NonEagerInputs(op, xs):
+      for inp in _NonEagerInputs(op, xs_set):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -125,7 +125,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _NonEagerInputs(op, xs):
+    for x in _NonEagerInputs(op, xs_set):
       if x.op in between_ops:
         pending_count[x.op] += 1
 
@@ -265,7 +265,7 @@ def _VerifyGeneratedGradients(grads, op):
                      "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -281,7 +281,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
     pending_count: mapping from operation to number of backprop inputs.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     The set of operations.
@@ -289,7 +289,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _NonEagerInputs(op, xs):
+    for inp in _NonEagerInputs(op, xs_set):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -369,7 +369,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
   """Raises an error if we backprop through a loop var."""
   # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
   # message.
@@ -383,7 +383,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -425,7 +425,7 @@ def _MaybeCaptured(t):
   return t
 
 
-def _NonEagerInputs(op, xs):
+def _NonEagerInputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
@@ -433,29 +433,28 @@ def _NonEagerInputs(op, xs):
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
 
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _Inputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -464,7 +463,7 @@ def _Inputs(op, xs):
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in tensors:
+      if t not in xs_set:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -546,6 +545,7 @@ def _GradientsHelper(ys,
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
         xs, name="x", as_ref=True)
+    xs_set = object_identity.ObjectIdentitySet(xs)
     grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                              gradient_uid)
 
@@ -562,7 +562,7 @@ def _GradientsHelper(ys,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
 
     # Iterate over the collected ops.
     #
@@ -596,7 +596,7 @@ def _GradientsHelper(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
-    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
@@ -649,7 +649,7 @@ def _GradientsHelper(ys,
             op._control_flow_context.IsWhileContext() and
             op._control_flow_context ==
             ops.get_default_graph()._get_control_flow_context()):
-          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
         # pylint: enable=protected-access
 
         if (grad_fn or is_func_call) and has_out_grads:
@@ -696,10 +696,10 @@ def _GradientsHelper(ys,
         else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
+          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -719,7 +719,7 @@ def _GradientsHelper(ys,
 
       # Update pending count for the inputs of op and enqueue ready ops.
       _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                    xs)
+                                    xs_set)
 
   if loop_state:
     loop_state.PostProcessing()
@@ -739,9 +739,9 @@ def _HasAnyNotNoneGrads(grads, op):
 
 
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                  xs):
+                                  xs_set):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _NonEagerInputs(op, xs):
+  for x in _NonEagerInputs(op, xs_set):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
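
A minimal, standalone sketch of the pattern behind this change, under stated assumptions: it is not TensorFlow code, the ObjectIdentitySet stand-in only approximates tensorflow.python.util.object_identity.ObjectIdentitySet, and the _inputs_*/walk_fast helper names are hypothetical. It shows why hoisting the identity-set construction out of the per-op helper turns an O(num_ops * len(xs)) cost into roughly O(num_ops + len(xs)).

# Illustrative sketch only; not part of the patch.
class ObjectIdentitySet(object):
  """Set keyed on object identity (id()), approximating the TF utility."""

  def __init__(self, items=()):
    self._ids = {id(x) for x in items}

  def __contains__(self, item):
    return id(item) in self._ids


def _inputs_slow(op_inputs, xs):
  # Before: the identity set over `xs` is rebuilt on every call, so visiting
  # N ops costs O(N * len(xs)) in set construction alone.
  xs_set = ObjectIdentitySet(xs)
  return [t for t in op_inputs if t not in xs_set]


def _inputs_fast(op_inputs, xs_set):
  # After: the caller builds the set once; each call only does O(1) lookups.
  return [t for t in op_inputs if t not in xs_set]


def walk_fast(all_op_inputs, xs):
  xs_set = ObjectIdentitySet(xs)  # built once, reused for every op
  return [_inputs_fast(op_inputs, xs_set) for op_inputs in all_op_inputs]

In the patch itself, _GradientsHelper plays the role of walk_fast: it builds xs_set once and threads it through _PendingCount, _StopOps, _Inputs, _NonEagerInputs, and _UpdatePendingAndEnqueueReady instead of letting _Inputs rebuild the set per op.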