diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 09ec4ee12b2..4069ef1c703 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -339,7 +339,9 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + ":imperative_grad", "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", @@ -425,3 +427,9 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +py_library( + name = "imperative_grad", + srcs = ["imperative_grad.py"], + deps = [":tape"], +) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 1d729cc2e19..3c84cbbd6fa 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import functools import operator import threading @@ -28,6 +27,7 @@ import six from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute +from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,288 +36,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.util import tf_contextlib +from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect -# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total -# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation -# so as to release the gradient tensor to save memory. -_MIN_AGGREGATE_COUNT = 4 -_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024 - -# Terminology: -# -# - op: a possibly composite operation, which has an entry in the tape -# - target: dy in dx/dy -# - source: dx in dx/dy -# - tensor: one of the many inputs or outputs of an operation -# -# Below here we do the gradient algorithm. It works as follows: -# -# First we filter the tape to just the subset of operations we want to -# differentiate. In the process of doing so we count how many times each Tensor -# is used as an input to an op (so we know when we're done computing gradients -# for that Tensor). We also count, for each tape entry, how many of its output -# Tensors need gradients to be computed (Tensors which are not used do not need -# any gradients to be computed). -# -# Finally, we start a backprop stack with a set of tape entries for which we -# have all gradients available. This set usually is a subset of the set of -# targets (not all since targets which have outputs in the tape will not have -# gradients available initially). -# -# Then we repeatedly pop an entry from the stack, run its backprop, and update -# the gradients of its inputs. Once we have computed all gradients for a single -# input we can mark this input as done, and this can trigger adding an entry to -# the stack if all outputs of that entry are now done. -# -# When the stack is empty we have gradients for all tensors we're interested in. - - -def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources): - """Filters the tape to only include relevant entries and counts tensor usages. - - Args: - target: the target to optimize. - tensor_to_op: Map from tensor id to key in op_to_entry that produced it. - op_to_entry: Map from op id to a tape.TapeEntry object - id_sources: the ids of the sources wrt the gradient is being taken. - - Returns: - usage counts (how many entries downstream from a tensor use it) - op_to_entry_map: entry map (a filtered tape, with only the relevant - entries), - missing: map from tensor id to how many downstream gradients still need - to be computed before this tensor's gradient can be computed. - """ - if isinstance(target, (ops.Tensor)): - tensor_stack = [ops.tensor_id(target)] - else: - tensor_stack = list([ops.tensor_id(x) for x in target]) - tensor_usage_counts = {} - o_to_e = {} # Copy of just the bits we need from op_to_entry - while tensor_stack: - t = tensor_stack.pop() - op = tensor_to_op.get(t, None) - # op is None if the tensor is a source (i.e. was watched directly) - if op is None or op in o_to_e: - continue - op_trace = op_to_entry[op] - o_to_e[op] = op_trace - for it in op_trace.input_ids: - if it in tensor_usage_counts: - tensor_usage_counts[it] += 1 - else: - tensor_usage_counts[it] = 1 - if it not in id_sources and it in tensor_to_op: - tensor_stack.append(it) - op_missing_tensor_counts = collections.defaultdict(int) - for t in tensor_usage_counts: - if t in tensor_to_op and tensor_to_op[t] is not None: - op_missing_tensor_counts[tensor_to_op[t]] += 1 - return tensor_usage_counts, o_to_e, op_missing_tensor_counts - - -def _initialize_backprop_stack(op_to_entry, op_missing_tensor): - """Returns the set of tape entries which are available for backprop.""" - ready_ops = [] - for op in op_to_entry: - if op not in op_missing_tensor: - ready_ops.append(op) - return ready_ops - - -def _initial_gradients(target, output_gradients, tensor_usage_counts): - """Computes the initial gradients for each Tensor.""" - # Initialize the backprop stack - gradients = collections.defaultdict(list) - if isinstance(target, ops.Tensor): - if output_gradients is not None: - output_gradient = output_gradients - else: - output_gradient = array_ops.ones_like(target) - gradients[ops.tensor_id(target)].append(output_gradient) - else: - for i, t in enumerate(target): - if ops.tensor_id(t) in tensor_usage_counts: - # Can't provide a gradient of something we're trying to differentiate - assert output_gradients is None or output_gradients[i] is None - else: - if output_gradients is None or output_gradients[i] is None: - out_grad = array_ops.ones_like(t) - else: - out_grad = output_gradients[i] - gradients[ops.tensor_id(t)].append(out_grad) - return gradients - - -@tf_contextlib.contextmanager -def _no_op(): - yield - - -def _aggregate_grads(gradients): - """Aggregate gradients from multiple sources. - - Args: - gradients: A list of 'Tensor' or 'IndexedSlices' gradients. - - Returns: - If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. - Otherwise returns an aggregated 'IndexedSlices'. - """ - assert gradients, "No gradients to aggregate" - - if len(gradients) == 1: - return gradients[0] - if all([isinstance(g, ops.Tensor) for g in gradients]): - return math_ops.add_n(gradients) - else: - assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices)) - for g in gradients]) - indexed_slices_list = [] - for grad in gradients: - # TODO(xpan): Support nested IndexedSlices and core IndexedSlices - if isinstance(grad, ops.Tensor): - indexed_slices = ops.IndexedSlices( - grad, - constant_op.constant(range(grad.shape[0])), - constant_op.constant(grad.shape.as_list())) - indexed_slices_list.append(indexed_slices) - else: - indexed_slices_list.append(grad) - - # Dense shapes from all gradients should be the same. - dense_shape = indexed_slices_list[0].dense_shape - # For simplicity now, always cast to int64. - indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64) - for x in indexed_slices_list], 0) - values = array_ops.concat([x.values for x in indexed_slices_list], 0) - return ops.IndexedSlices(values, indices, dense_shape) - - -def _add_new_grads(gradients, gradients_size, tid, grad): - """Adds a new gradient and maybe aggregate the gradients. - - Args: - gradients: A dict map from tensor id to list of gradients. - gradients_size: A dict map from tensor id to its total units. Might - not be initialized. - tid: Tensor id. - grad: New gradient for the `tid`, either a Tensor or IndexedSlices. - - Raises: - ValueError: if `grad` is neight Tensor nor IndexedSlices. - """ - tensor_grads = gradients[tid] - tensor_grads.append(grad) - if len(tensor_grads) < _MIN_AGGREGATE_COUNT: - return - elif tid not in gradients_size: - if isinstance(grad, ops.Tensor): - size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access - elif isinstance(grad, ops.IndexedSlices): - size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access - else: - raise ValueError("Unexpected gradient type: %s" % type(grad)) - gradients_size[tid] = size - else: - size = gradients_size[tid] - - # For simplicity, assume each element to be 4 bytes now. - if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES: - gradients[tid] = [_aggregate_grads(tensor_grads)] - - -def imperative_grad( - target, - sources, - output_gradients=None): - """Computes gradients from the imperatively defined tape on top of the stack. - - Works by filtering the tape, computing how many downstream usages are of each - tensor and entry, and repeatedly applying backward functions until we have - gradients for all sources. - - Args: - target: either a Tensor or list of Tensors to be differentiated. - sources: list of Tensors for which we want gradients - output_gradients: if not None, a list of gradient provided for each Target, - or None if we are to use the target's computed downstream gradient. - - Returns: - the gradient wrt each of the sources. - - Raises: - RuntimeError: if something goes wrong. - ValueError: if there is no sequence of differentiable operations connecting - a source and any target Tensor. This can happen either if the target is - not computed based on the source, if the tracing was set up incorrectly, - or if only non-differentiable functions of the source were used in the - computation of target. - """ - if not tape._tape_stack.stack: # pylint: disable=protected-access - raise RuntimeError("Computing a gradient with no tape present") - bp_tape = tape.pop_tape() - tensor_to_op, op_to_entry = bp_tape.export() - # This overwrites the op_to_entry variable, which will release all memory used - # to keep traces that are irrelevant to the gradient computation we're doing - # here. - id_sources = [ops.tensor_id(t) for t in sources] - tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop( - target, tensor_to_op, op_to_entry, id_sources) - ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor) - gradients = _initial_gradients(target, output_gradients, - tensor_usage_counts) - gradients_size = dict() - # Now exhaust the backprop stack - while ready_ops: - op = ready_ops.pop() - op_trace = op_to_entry.pop(op) - out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids] - for i in range(len(out_gradients)): - if out_gradients[i] is None: - # TODO(apassos) this should be in the right device - none_indices = _grad_fn_accepts_none_for_indices.get( - op_trace.op_type, None) - if none_indices is None or i not in none_indices: - out_gradients[i] = array_ops.zeros( - *op_trace.output_shape_and_dtype[i]) - else: - out_gradients[i] = _aggregate_grads(out_gradients[i]) - - in_gradients = op_trace.backward_function( - *(out_gradients + op_trace.side_outputs)) - in_gradients = ([in_gradients] - if isinstance(in_gradients, (ops.Tensor, - ops.IndexedSlices, - type(None))) - else in_gradients) - for i, t in enumerate(op_trace.input_ids): - if in_gradients[i] is not None: - _add_new_grads(gradients, gradients_size, t, in_gradients[i]) - if tensor_usage_counts.get(t, 0) > 0: - tensor_usage_counts[t] -= 1 - if (t in tensor_to_op - and tensor_usage_counts[t] == 0 - and t not in id_sources): - in_op = tensor_to_op[t] - if in_op is None: - continue - if op_missing_tensor.get(in_op, 0) > 0: - op_missing_tensor[in_op] -= 1 - if op_missing_tensor.get(in_op, 0) == 0: - ready_ops.append(in_op) - result = [] - for i, s in enumerate(sources): - g = gradients.get(ops.tensor_id(s), None) - if g is None: - result.append(None) - else: - result.append(_aggregate_grads(g)) - return result - _op_attr_type_cache = {} @@ -557,7 +279,7 @@ def _record_gradient(op_name, inputs, attrs, results, name): if _tracing: print("Gradient for", (name if name else op_name), "inputs", op_inputs, "output_grads", orig_outputs, "gradients", result) - return result + return nest.flatten(result) tape.record_operation(op_name, results, inputs, [], grad_fn) if _tracing: @@ -615,7 +337,9 @@ def implicit_val_and_grad(f): end_node = f(*args) variables = tape.top_tape_watched_variables() sources = [x.handle for x in variables] - grad = imperative_grad(end_node, sources) + grad = imperative_grad.imperative_grad(_default_vspace, + nest.flatten(end_node), + sources) return end_node, list(zip(grad, variables)) return grad_fn @@ -849,6 +573,96 @@ def val_and_grad_function(f, params=None): sources.append(args[i]) tape.watch(args[i]) result = f(*args) - return result, imperative_grad(result, sources, output_gradients=dy) + return result, imperative_grad.imperative_grad( + _default_vspace, nest.flatten(result), sources, + output_gradients=nest.flatten(dy) if dy is not None else None) return decorated + + +def _aggregate_grads(gradients): + """Aggregate gradients from multiple sources. + + Args: + gradients: A list of 'Tensor' or 'IndexedSlices' gradients. + + Returns: + If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. + Otherwise returns an aggregated 'IndexedSlices'. + """ + assert gradients, "No gradients to aggregate" + + if len(gradients) == 1: + return gradients[0] + if all([isinstance(g, ops.Tensor) for g in gradients]): + return math_ops.add_n(gradients) + else: + assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices)) + for g in gradients]) + indexed_slices_list = [] + for grad in gradients: + # TODO(xpan): Support nested IndexedSlices and core IndexedSlices + if isinstance(grad, ops.Tensor): + indexed_slices = ops.IndexedSlices( + grad, + constant_op.constant(range(grad.shape[0])), + constant_op.constant(grad.shape.as_list())) + indexed_slices_list.append(indexed_slices) + else: + indexed_slices_list.append(grad) + + # Dense shapes from all gradients should be the same. + dense_shape = indexed_slices_list[0].dense_shape + # For simplicity now, always cast to int64. + indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64) + for x in indexed_slices_list], 0) + values = array_ops.concat([x.values for x in indexed_slices_list], 0) + return ops.IndexedSlices(values, indices, dense_shape) + + +# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total +# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation +# so as to release the gradient tensor to save memory. +_MIN_AGGREGATE_COUNT = 4 +_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024 + + +def _add_new_grads(gradients, gradients_size, tid, grad): + """Adds a new gradient and maybe aggregate the gradients. + + Args: + gradients: A dict map from tensor id to list of gradients. + gradients_size: A dict map from tensor id to its total units. Might + not be initialized. + tid: Tensor id. + grad: New gradient for the `tid`, either a Tensor or IndexedSlices. + + Raises: + ValueError: if `grad` is neight Tensor nor IndexedSlices. + """ + tensor_grads = gradients[tid] + tensor_grads.append(grad) + if len(tensor_grads) < _MIN_AGGREGATE_COUNT: + return + elif tid not in gradients_size: + if isinstance(grad, ops.Tensor): + size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access + elif isinstance(grad, ops.IndexedSlices): + size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access + else: + raise ValueError("Unexpected gradient type: %s" % type(grad)) + gradients_size[tid] = size + else: + size = gradients_size[tid] + + # For simplicity, assume each element to be 4 bytes now. + if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES: + gradients[tid] = [_aggregate_grads(tensor_grads)] + + +_default_vspace = imperative_grad.VSpace( + add_new_grads_fn=_add_new_grads, + aggregate_fn=_aggregate_grads, + tensor_id=ops.tensor_id, + zeros=array_ops.zeros, + ones_like=array_ops.ones_like) diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py index 67c9015bf0c..4360e53225c 100644 --- a/tensorflow/python/eager/custom_gradient.py +++ b/tensorflow/python/eager/custom_gradient.py @@ -78,7 +78,7 @@ def custom_gradient(f): # second derivative this way if they capture any output tensors. Change the # signature of custom_gradient. def actual_grad_fn(*outputs): - return grad_fn(*outputs) + return nest.flatten(grad_fn(*outputs)) flat_result = nest.flatten(result) tape.record_operation( diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index cb70d23f068..6ffc914f73e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -88,7 +88,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): else: captured_value = captured_value[1] tape.record_operation("captured_value", [captured_value], [value], [], - lambda x: x) + lambda x: [x]) return captured_value diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py new file mode 100644 index 00000000000..b81f5bba145 --- /dev/null +++ b/tensorflow/python/eager/imperative_grad.py @@ -0,0 +1,227 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Code for backpropagation using the tape utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.eager import tape + + +# Terminology: +# +# - op: a possibly composite operation, which has an entry in the tape +# - target: dy in dx/dy +# - source: dx in dx/dy +# - tensor: one of the many inputs or outputs of an operation +# +# Below here we do the gradient algorithm. It works as follows: +# +# First we filter the tape to just the subset of operations we want to +# differentiate. In the process of doing so we count how many times each Tensor +# is used as an input to an op (so we know when we're done computing gradients +# for that Tensor). We also count, for each tape entry, how many of its output +# Tensors need gradients to be computed (Tensors which are not used do not need +# any gradients to be computed). +# +# Finally, we start a backprop stack with a set of tape entries for which we +# have all gradients available. This set usually is a subset of the set of +# targets (not all since targets which have outputs in the tape will not have +# gradients available initially). +# +# Then we repeatedly pop an entry from the stack, run its backprop, and update +# the gradients of its inputs. Once we have computed all gradients for a single +# input we can mark this input as done, and this can trigger adding an entry to +# the stack if all outputs of that entry are now done. +# +# When the stack is empty we have gradients for all tensors we're interested in. +def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources): + """Filters the tape to only include relevant entries and counts tensor usages. + + Args: + vspace: information about the space we're differentiating in. + target: the target to optimize. + tensor_to_op: Map from tensor id to key in op_to_entry that produced it. + op_to_entry: Map from op id to a tape.TapeEntry object + id_sources: the ids of the sources wrt the gradient is being taken. + + Returns: + usage counts (how many entries downstream from a tensor use it) + op_to_entry_map: entry map (a filtered tape, with only the relevant + entries), + missing: map from tensor id to how many downstream gradients still need + to be computed before this tensor's gradient can be computed. + """ + tensor_stack = [vspace.tensor_id(x) for x in target] + tensor_usage_counts = {} + o_to_e = {} # Copy of just the bits we need from op_to_entry + while tensor_stack: + t = tensor_stack.pop() + op = tensor_to_op.get(t, None) + # op is None if the tensor is a source (i.e. was watched directly) + if op is None or op in o_to_e: + continue + op_trace = op_to_entry[op] + o_to_e[op] = op_trace + for it in op_trace.input_ids: + if it in tensor_usage_counts: + tensor_usage_counts[it] += 1 + else: + tensor_usage_counts[it] = 1 + if it not in id_sources and it in tensor_to_op: + tensor_stack.append(it) + op_missing_tensor_counts = collections.defaultdict(int) + for t in tensor_usage_counts: + if t in tensor_to_op and tensor_to_op[t] is not None: + op_missing_tensor_counts[tensor_to_op[t]] += 1 + return tensor_usage_counts, o_to_e, op_missing_tensor_counts + + +def _initialize_backprop_stack(op_to_entry, op_missing_tensor): + """Returns the set of tape entries which are available for backprop.""" + ready_ops = [] + for op in op_to_entry: + if op not in op_missing_tensor: + ready_ops.append(op) + return ready_ops + + +def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts): + """Computes the initial gradients for each Tensor.""" + # Initialize the backprop stack + gradients = collections.defaultdict(list) + for i, t in enumerate(target): + if vspace.tensor_id(t) in tensor_usage_counts: + # Can't provide a gradient of something we're trying to differentiate + assert output_gradients is None or output_gradients[i] is None + else: + if output_gradients is None or output_gradients[i] is None: + out_grad = vspace.ones_like(t) + else: + out_grad = output_gradients[i] + gradients[vspace.tensor_id(t)].append(out_grad) + return gradients + + +VSpace = collections.namedtuple( + "VSpace", + ["add_new_grads_fn", "aggregate_fn", "tensor_id", "zeros", "ones_like"]) + + +def imperative_grad( + vspace, + target, + sources, + output_gradients=None): + """Computes gradients from the imperatively defined tape on top of the stack. + + Works by filtering the tape, computing how many downstream usages are of each + tensor and entry, and repeatedly applying backward functions until we have + gradients for all sources. + + Args: + vspace: the vector space in which to differentiate. + target: either a Tensor or list of Tensors to be differentiated. + sources: list of Tensors for which we want gradients + output_gradients: if not None, a list of gradient provided for each Target, + or None if we are to use the target's computed downstream gradient. + + Returns: + the gradient wrt each of the sources. + + Raises: + RuntimeError: if something goes wrong. + ValueError: if there is no sequence of differentiable operations connecting + a source and any target Tensor. This can happen either if the target is + not computed based on the source, if the tracing was set up incorrectly, + or if only non-differentiable functions of the source were used in the + computation of target. + """ + if not tape._tape_stack.stack: # pylint: disable=protected-access + raise RuntimeError("Computing a gradient with no tape present") + bp_tape = tape.pop_tape() + tensor_to_op, op_to_entry = bp_tape.export() + # This overwrites the op_to_entry variable, which will release all memory used + # to keep traces that are irrelevant to the gradient computation we're doing + # here. + id_sources = [vspace.tensor_id(t) for t in sources] + tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop( + vspace, target, tensor_to_op, op_to_entry, id_sources) + ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor) + gradients = _initial_gradients(vspace, target, output_gradients, + tensor_usage_counts) + gradients_size = dict() + # Now exhaust the backprop stack + while ready_ops: + op = ready_ops.pop() + op_trace = op_to_entry.pop(op) + out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids] + for i in range(len(out_gradients)): + if out_gradients[i] is None: + # TODO(apassos) this should be in the right device + none_indices = _grad_fn_accepts_none_for_indices.get( + op_trace.op_type, None) + if none_indices is None or i not in none_indices: + out_gradients[i] = vspace.zeros( + *op_trace.output_shape_and_dtype[i]) + else: + out_gradients[i] = vspace.aggregate_fn(out_gradients[i]) + + in_gradients = op_trace.backward_function( + *(out_gradients + op_trace.side_outputs)) + for i, t in enumerate(op_trace.input_ids): + if in_gradients[i] is not None: + vspace.add_new_grads_fn(gradients, gradients_size, t, in_gradients[i]) + if tensor_usage_counts.get(t, 0) > 0: + tensor_usage_counts[t] -= 1 + if (t in tensor_to_op + and tensor_usage_counts[t] == 0 + and t not in id_sources): + in_op = tensor_to_op[t] + if in_op is None: + continue + if op_missing_tensor.get(in_op, 0) > 0: + op_missing_tensor[in_op] -= 1 + if op_missing_tensor.get(in_op, 0) == 0: + ready_ops.append(in_op) + result = [] + for i, s in enumerate(sources): + g = gradients.get(vspace.tensor_id(s), None) + if g is None: + result.append(None) + else: + result.append(vspace.aggregate_fn(g)) + return result + + +# TODO(agarwal): use an automatic mechanism for handling None arguments to +# gradient functions. +# Some gradient functions can accept None arguments for gradients. The following +# maps the operation name to the indices at which the corresponding gradient +# function can accept None values. +# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values +# during backprop. However the gradient function uses only the first of those +# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4], +# indicates that only the gradient corresponding to index 0 is used, and the +# gradient values at indices 1-4 are ignored (and hence can be None). The +# backprop algorithm can then leverage this by not constructing zeros to +# pass for those indices. +_grad_fn_accepts_none_for_indices = { + "SoftmaxCrossEntropyWithLogits": [1], + "FusedBatchNorm": [1, 2, 3, 4] +} diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 50aa070985b..ae842976908 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -675,7 +675,7 @@ class _EagerTensorBase(Tensor): if not context.in_graph_mode(): self_device = self.device def grad_fun(dresult): - return dresult._copy(device_name=self_device) + return [dresult._copy(device_name=self_device)] tape.record_operation("_copy", [new_tensor], [self], [], grad_fun) return new_tensor # pylint: enable=protected-access