Splits backprop.py in two files, one of which can be converted to C

PiperOrigin-RevId: 171165855
This commit is contained in:
Alexandre Passos 2017-10-05 09:50:49 -07:00 committed by TensorFlower Gardener
parent 7e7d55c0f5
commit 5f97262ae6
6 changed files with 335 additions and 286 deletions

View File

@ -339,7 +339,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":imperative_grad",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
@ -425,3 +427,9 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
py_library(
name = "imperative_grad",
srcs = ["imperative_grad.py"],
deps = [":tape"],
)

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import operator
import threading
@ -28,6 +27,7 @@ import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -36,288 +36,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
# so as to release the gradient tensor to save memory.
_MIN_AGGREGATE_COUNT = 4
_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
# Terminology:
#
# - op: a possibly composite operation, which has an entry in the tape
# - target: dy in dx/dy
# - source: dx in dx/dy
# - tensor: one of the many inputs or outputs of an operation
#
# Below here we do the gradient algorithm. It works as follows:
#
# First we filter the tape to just the subset of operations we want to
# differentiate. In the process of doing so we count how many times each Tensor
# is used as an input to an op (so we know when we're done computing gradients
# for that Tensor). We also count, for each tape entry, how many of its output
# Tensors need gradients to be computed (Tensors which are not used do not need
# any gradients to be computed).
#
# Finally, we start a backprop stack with a set of tape entries for which we
# have all gradients available. This set usually is a subset of the set of
# targets (not all since targets which have outputs in the tape will not have
# gradients available initially).
#
# Then we repeatedly pop an entry from the stack, run its backprop, and update
# the gradients of its inputs. Once we have computed all gradients for a single
# input we can mark this input as done, and this can trigger adding an entry to
# the stack if all outputs of that entry are now done.
#
# When the stack is empty we have gradients for all tensors we're interested in.
def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
"""Filters the tape to only include relevant entries and counts tensor usages.
Args:
target: the target to optimize.
tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
op_to_entry: Map from op id to a tape.TapeEntry object
id_sources: the ids of the sources wrt the gradient is being taken.
Returns:
usage counts (how many entries downstream from a tensor use it)
op_to_entry_map: entry map (a filtered tape, with only the relevant
entries),
missing: map from tensor id to how many downstream gradients still need
to be computed before this tensor's gradient can be computed.
"""
if isinstance(target, (ops.Tensor)):
tensor_stack = [ops.tensor_id(target)]
else:
tensor_stack = list([ops.tensor_id(x) for x in target])
tensor_usage_counts = {}
o_to_e = {} # Copy of just the bits we need from op_to_entry
while tensor_stack:
t = tensor_stack.pop()
op = tensor_to_op.get(t, None)
# op is None if the tensor is a source (i.e. was watched directly)
if op is None or op in o_to_e:
continue
op_trace = op_to_entry[op]
o_to_e[op] = op_trace
for it in op_trace.input_ids:
if it in tensor_usage_counts:
tensor_usage_counts[it] += 1
else:
tensor_usage_counts[it] = 1
if it not in id_sources and it in tensor_to_op:
tensor_stack.append(it)
op_missing_tensor_counts = collections.defaultdict(int)
for t in tensor_usage_counts:
if t in tensor_to_op and tensor_to_op[t] is not None:
op_missing_tensor_counts[tensor_to_op[t]] += 1
return tensor_usage_counts, o_to_e, op_missing_tensor_counts
def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
"""Returns the set of tape entries which are available for backprop."""
ready_ops = []
for op in op_to_entry:
if op not in op_missing_tensor:
ready_ops.append(op)
return ready_ops
def _initial_gradients(target, output_gradients, tensor_usage_counts):
"""Computes the initial gradients for each Tensor."""
# Initialize the backprop stack
gradients = collections.defaultdict(list)
if isinstance(target, ops.Tensor):
if output_gradients is not None:
output_gradient = output_gradients
else:
output_gradient = array_ops.ones_like(target)
gradients[ops.tensor_id(target)].append(output_gradient)
else:
for i, t in enumerate(target):
if ops.tensor_id(t) in tensor_usage_counts:
# Can't provide a gradient of something we're trying to differentiate
assert output_gradients is None or output_gradients[i] is None
else:
if output_gradients is None or output_gradients[i] is None:
out_grad = array_ops.ones_like(t)
else:
out_grad = output_gradients[i]
gradients[ops.tensor_id(t)].append(out_grad)
return gradients
@tf_contextlib.contextmanager
def _no_op():
yield
def _aggregate_grads(gradients):
"""Aggregate gradients from multiple sources.
Args:
gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
Returns:
If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
Otherwise returns an aggregated 'IndexedSlices'.
"""
assert gradients, "No gradients to aggregate"
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
return math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
indexed_slices_list = []
for grad in gradients:
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
if isinstance(grad, ops.Tensor):
indexed_slices = ops.IndexedSlices(
grad,
constant_op.constant(range(grad.shape[0])),
constant_op.constant(grad.shape.as_list()))
indexed_slices_list.append(indexed_slices)
else:
indexed_slices_list.append(grad)
# Dense shapes from all gradients should be the same.
dense_shape = indexed_slices_list[0].dense_shape
# For simplicity now, always cast to int64.
indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
for x in indexed_slices_list], 0)
values = array_ops.concat([x.values for x in indexed_slices_list], 0)
return ops.IndexedSlices(values, indices, dense_shape)
def _add_new_grads(gradients, gradients_size, tid, grad):
"""Adds a new gradient and maybe aggregate the gradients.
Args:
gradients: A dict map from tensor id to list of gradients.
gradients_size: A dict map from tensor id to its total units. Might
not be initialized.
tid: Tensor id.
grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
Raises:
ValueError: if `grad` is neight Tensor nor IndexedSlices.
"""
tensor_grads = gradients[tid]
tensor_grads.append(grad)
if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
return
elif tid not in gradients_size:
if isinstance(grad, ops.Tensor):
size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
elif isinstance(grad, ops.IndexedSlices):
size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
else:
raise ValueError("Unexpected gradient type: %s" % type(grad))
gradients_size[tid] = size
else:
size = gradients_size[tid]
# For simplicity, assume each element to be 4 bytes now.
if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
gradients[tid] = [_aggregate_grads(tensor_grads)]
def imperative_grad(
target,
sources,
output_gradients=None):
"""Computes gradients from the imperatively defined tape on top of the stack.
Works by filtering the tape, computing how many downstream usages are of each
tensor and entry, and repeatedly applying backward functions until we have
gradients for all sources.
Args:
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
output_gradients: if not None, a list of gradient provided for each Target,
or None if we are to use the target's computed downstream gradient.
Returns:
the gradient wrt each of the sources.
Raises:
RuntimeError: if something goes wrong.
ValueError: if there is no sequence of differentiable operations connecting
a source and any target Tensor. This can happen either if the target is
not computed based on the source, if the tracing was set up incorrectly,
or if only non-differentiable functions of the source were used in the
computation of target.
"""
if not tape._tape_stack.stack: # pylint: disable=protected-access
raise RuntimeError("Computing a gradient with no tape present")
bp_tape = tape.pop_tape()
tensor_to_op, op_to_entry = bp_tape.export()
# This overwrites the op_to_entry variable, which will release all memory used
# to keep traces that are irrelevant to the gradient computation we're doing
# here.
id_sources = [ops.tensor_id(t) for t in sources]
tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
target, tensor_to_op, op_to_entry, id_sources)
ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
gradients = _initial_gradients(target, output_gradients,
tensor_usage_counts)
gradients_size = dict()
# Now exhaust the backprop stack
while ready_ops:
op = ready_ops.pop()
op_trace = op_to_entry.pop(op)
out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
for i in range(len(out_gradients)):
if out_gradients[i] is None:
# TODO(apassos) this should be in the right device
none_indices = _grad_fn_accepts_none_for_indices.get(
op_trace.op_type, None)
if none_indices is None or i not in none_indices:
out_gradients[i] = array_ops.zeros(
*op_trace.output_shape_and_dtype[i])
else:
out_gradients[i] = _aggregate_grads(out_gradients[i])
in_gradients = op_trace.backward_function(
*(out_gradients + op_trace.side_outputs))
in_gradients = ([in_gradients]
if isinstance(in_gradients, (ops.Tensor,
ops.IndexedSlices,
type(None)))
else in_gradients)
for i, t in enumerate(op_trace.input_ids):
if in_gradients[i] is not None:
_add_new_grads(gradients, gradients_size, t, in_gradients[i])
if tensor_usage_counts.get(t, 0) > 0:
tensor_usage_counts[t] -= 1
if (t in tensor_to_op
and tensor_usage_counts[t] == 0
and t not in id_sources):
in_op = tensor_to_op[t]
if in_op is None:
continue
if op_missing_tensor.get(in_op, 0) > 0:
op_missing_tensor[in_op] -= 1
if op_missing_tensor.get(in_op, 0) == 0:
ready_ops.append(in_op)
result = []
for i, s in enumerate(sources):
g = gradients.get(ops.tensor_id(s), None)
if g is None:
result.append(None)
else:
result.append(_aggregate_grads(g))
return result
_op_attr_type_cache = {}
@ -557,7 +279,7 @@ def _record_gradient(op_name, inputs, attrs, results, name):
if _tracing:
print("Gradient for", (name if name else op_name), "inputs", op_inputs,
"output_grads", orig_outputs, "gradients", result)
return result
return nest.flatten(result)
tape.record_operation(op_name, results, inputs, [], grad_fn)
if _tracing:
@ -615,7 +337,9 @@ def implicit_val_and_grad(f):
end_node = f(*args)
variables = tape.top_tape_watched_variables()
sources = [x.handle for x in variables]
grad = imperative_grad(end_node, sources)
grad = imperative_grad.imperative_grad(_default_vspace,
nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
return grad_fn
@ -849,6 +573,96 @@ def val_and_grad_function(f, params=None):
sources.append(args[i])
tape.watch(args[i])
result = f(*args)
return result, imperative_grad(result, sources, output_gradients=dy)
return result, imperative_grad.imperative_grad(
_default_vspace, nest.flatten(result), sources,
output_gradients=nest.flatten(dy) if dy is not None else None)
return decorated
def _aggregate_grads(gradients):
"""Aggregate gradients from multiple sources.
Args:
gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
Returns:
If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
Otherwise returns an aggregated 'IndexedSlices'.
"""
assert gradients, "No gradients to aggregate"
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
return math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
indexed_slices_list = []
for grad in gradients:
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
if isinstance(grad, ops.Tensor):
indexed_slices = ops.IndexedSlices(
grad,
constant_op.constant(range(grad.shape[0])),
constant_op.constant(grad.shape.as_list()))
indexed_slices_list.append(indexed_slices)
else:
indexed_slices_list.append(grad)
# Dense shapes from all gradients should be the same.
dense_shape = indexed_slices_list[0].dense_shape
# For simplicity now, always cast to int64.
indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
for x in indexed_slices_list], 0)
values = array_ops.concat([x.values for x in indexed_slices_list], 0)
return ops.IndexedSlices(values, indices, dense_shape)
# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
# so as to release the gradient tensor to save memory.
_MIN_AGGREGATE_COUNT = 4
_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
def _add_new_grads(gradients, gradients_size, tid, grad):
"""Adds a new gradient and maybe aggregate the gradients.
Args:
gradients: A dict map from tensor id to list of gradients.
gradients_size: A dict map from tensor id to its total units. Might
not be initialized.
tid: Tensor id.
grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
Raises:
ValueError: if `grad` is neight Tensor nor IndexedSlices.
"""
tensor_grads = gradients[tid]
tensor_grads.append(grad)
if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
return
elif tid not in gradients_size:
if isinstance(grad, ops.Tensor):
size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
elif isinstance(grad, ops.IndexedSlices):
size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
else:
raise ValueError("Unexpected gradient type: %s" % type(grad))
gradients_size[tid] = size
else:
size = gradients_size[tid]
# For simplicity, assume each element to be 4 bytes now.
if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
gradients[tid] = [_aggregate_grads(tensor_grads)]
_default_vspace = imperative_grad.VSpace(
add_new_grads_fn=_add_new_grads,
aggregate_fn=_aggregate_grads,
tensor_id=ops.tensor_id,
zeros=array_ops.zeros,
ones_like=array_ops.ones_like)

View File

@ -78,7 +78,7 @@ def custom_gradient(f):
# second derivative this way if they capture any output tensors. Change the
# signature of custom_gradient.
def actual_grad_fn(*outputs):
return grad_fn(*outputs)
return nest.flatten(grad_fn(*outputs))
flat_result = nest.flatten(result)
tape.record_operation(

View File

@ -88,7 +88,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
else:
captured_value = captured_value[1]
tape.record_operation("captured_value", [captured_value], [value], [],
lambda x: x)
lambda x: [x])
return captured_value

View File

@ -0,0 +1,227 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for backpropagation using the tape utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.eager import tape
# Terminology:
#
# - op: a possibly composite operation, which has an entry in the tape
# - target: dy in dx/dy
# - source: dx in dx/dy
# - tensor: one of the many inputs or outputs of an operation
#
# Below here we do the gradient algorithm. It works as follows:
#
# First we filter the tape to just the subset of operations we want to
# differentiate. In the process of doing so we count how many times each Tensor
# is used as an input to an op (so we know when we're done computing gradients
# for that Tensor). We also count, for each tape entry, how many of its output
# Tensors need gradients to be computed (Tensors which are not used do not need
# any gradients to be computed).
#
# Finally, we start a backprop stack with a set of tape entries for which we
# have all gradients available. This set usually is a subset of the set of
# targets (not all since targets which have outputs in the tape will not have
# gradients available initially).
#
# Then we repeatedly pop an entry from the stack, run its backprop, and update
# the gradients of its inputs. Once we have computed all gradients for a single
# input we can mark this input as done, and this can trigger adding an entry to
# the stack if all outputs of that entry are now done.
#
# When the stack is empty we have gradients for all tensors we're interested in.
def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
"""Filters the tape to only include relevant entries and counts tensor usages.
Args:
vspace: information about the space we're differentiating in.
target: the target to optimize.
tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
op_to_entry: Map from op id to a tape.TapeEntry object
id_sources: the ids of the sources wrt the gradient is being taken.
Returns:
usage counts (how many entries downstream from a tensor use it)
op_to_entry_map: entry map (a filtered tape, with only the relevant
entries),
missing: map from tensor id to how many downstream gradients still need
to be computed before this tensor's gradient can be computed.
"""
tensor_stack = [vspace.tensor_id(x) for x in target]
tensor_usage_counts = {}
o_to_e = {} # Copy of just the bits we need from op_to_entry
while tensor_stack:
t = tensor_stack.pop()
op = tensor_to_op.get(t, None)
# op is None if the tensor is a source (i.e. was watched directly)
if op is None or op in o_to_e:
continue
op_trace = op_to_entry[op]
o_to_e[op] = op_trace
for it in op_trace.input_ids:
if it in tensor_usage_counts:
tensor_usage_counts[it] += 1
else:
tensor_usage_counts[it] = 1
if it not in id_sources and it in tensor_to_op:
tensor_stack.append(it)
op_missing_tensor_counts = collections.defaultdict(int)
for t in tensor_usage_counts:
if t in tensor_to_op and tensor_to_op[t] is not None:
op_missing_tensor_counts[tensor_to_op[t]] += 1
return tensor_usage_counts, o_to_e, op_missing_tensor_counts
def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
"""Returns the set of tape entries which are available for backprop."""
ready_ops = []
for op in op_to_entry:
if op not in op_missing_tensor:
ready_ops.append(op)
return ready_ops
def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
"""Computes the initial gradients for each Tensor."""
# Initialize the backprop stack
gradients = collections.defaultdict(list)
for i, t in enumerate(target):
if vspace.tensor_id(t) in tensor_usage_counts:
# Can't provide a gradient of something we're trying to differentiate
assert output_gradients is None or output_gradients[i] is None
else:
if output_gradients is None or output_gradients[i] is None:
out_grad = vspace.ones_like(t)
else:
out_grad = output_gradients[i]
gradients[vspace.tensor_id(t)].append(out_grad)
return gradients
VSpace = collections.namedtuple(
"VSpace",
["add_new_grads_fn", "aggregate_fn", "tensor_id", "zeros", "ones_like"])
def imperative_grad(
vspace,
target,
sources,
output_gradients=None):
"""Computes gradients from the imperatively defined tape on top of the stack.
Works by filtering the tape, computing how many downstream usages are of each
tensor and entry, and repeatedly applying backward functions until we have
gradients for all sources.
Args:
vspace: the vector space in which to differentiate.
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
output_gradients: if not None, a list of gradient provided for each Target,
or None if we are to use the target's computed downstream gradient.
Returns:
the gradient wrt each of the sources.
Raises:
RuntimeError: if something goes wrong.
ValueError: if there is no sequence of differentiable operations connecting
a source and any target Tensor. This can happen either if the target is
not computed based on the source, if the tracing was set up incorrectly,
or if only non-differentiable functions of the source were used in the
computation of target.
"""
if not tape._tape_stack.stack: # pylint: disable=protected-access
raise RuntimeError("Computing a gradient with no tape present")
bp_tape = tape.pop_tape()
tensor_to_op, op_to_entry = bp_tape.export()
# This overwrites the op_to_entry variable, which will release all memory used
# to keep traces that are irrelevant to the gradient computation we're doing
# here.
id_sources = [vspace.tensor_id(t) for t in sources]
tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
vspace, target, tensor_to_op, op_to_entry, id_sources)
ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
gradients = _initial_gradients(vspace, target, output_gradients,
tensor_usage_counts)
gradients_size = dict()
# Now exhaust the backprop stack
while ready_ops:
op = ready_ops.pop()
op_trace = op_to_entry.pop(op)
out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
for i in range(len(out_gradients)):
if out_gradients[i] is None:
# TODO(apassos) this should be in the right device
none_indices = _grad_fn_accepts_none_for_indices.get(
op_trace.op_type, None)
if none_indices is None or i not in none_indices:
out_gradients[i] = vspace.zeros(
*op_trace.output_shape_and_dtype[i])
else:
out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
in_gradients = op_trace.backward_function(
*(out_gradients + op_trace.side_outputs))
for i, t in enumerate(op_trace.input_ids):
if in_gradients[i] is not None:
vspace.add_new_grads_fn(gradients, gradients_size, t, in_gradients[i])
if tensor_usage_counts.get(t, 0) > 0:
tensor_usage_counts[t] -= 1
if (t in tensor_to_op
and tensor_usage_counts[t] == 0
and t not in id_sources):
in_op = tensor_to_op[t]
if in_op is None:
continue
if op_missing_tensor.get(in_op, 0) > 0:
op_missing_tensor[in_op] -= 1
if op_missing_tensor.get(in_op, 0) == 0:
ready_ops.append(in_op)
result = []
for i, s in enumerate(sources):
g = gradients.get(vspace.tensor_id(s), None)
if g is None:
result.append(None)
else:
result.append(vspace.aggregate_fn(g))
return result
# TODO(agarwal): use an automatic mechanism for handling None arguments to
# gradient functions.
# Some gradient functions can accept None arguments for gradients. The following
# maps the operation name to the indices at which the corresponding gradient
# function can accept None values.
# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
# during backprop. However the gradient function uses only the first of those
# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
# indicates that only the gradient corresponding to index 0 is used, and the
# gradient values at indices 1-4 are ignored (and hence can be None). The
# backprop algorithm can then leverage this by not constructing zeros to
# pass for those indices.
_grad_fn_accepts_none_for_indices = {
"SoftmaxCrossEntropyWithLogits": [1],
"FusedBatchNorm": [1, 2, 3, 4]
}

View File

@ -675,7 +675,7 @@ class _EagerTensorBase(Tensor):
if not context.in_graph_mode():
self_device = self.device
def grad_fun(dresult):
return dresult._copy(device_name=self_device)
return [dresult._copy(device_name=self_device)]
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
return new_tensor
# pylint: enable=protected-access