Splits backprop.py in two files, one of which can be converted to C
PiperOrigin-RevId: 171165855
This commit is contained in:
parent
7e7d55c0f5
commit
5f97262ae6
@ -339,7 +339,9 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":imperative_grad",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
@ -425,3 +427,9 @@ filegroup(
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "imperative_grad",
|
||||
srcs = ["imperative_grad.py"],
|
||||
deps = [":tape"],
|
||||
)
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import operator
|
||||
import threading
|
||||
@ -28,6 +27,7 @@ import six
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import imperative_grad
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -36,288 +36,10 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
|
||||
# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
|
||||
# so as to release the gradient tensor to save memory.
|
||||
_MIN_AGGREGATE_COUNT = 4
|
||||
_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
|
||||
|
||||
# Terminology:
|
||||
#
|
||||
# - op: a possibly composite operation, which has an entry in the tape
|
||||
# - target: dy in dx/dy
|
||||
# - source: dx in dx/dy
|
||||
# - tensor: one of the many inputs or outputs of an operation
|
||||
#
|
||||
# Below here we do the gradient algorithm. It works as follows:
|
||||
#
|
||||
# First we filter the tape to just the subset of operations we want to
|
||||
# differentiate. In the process of doing so we count how many times each Tensor
|
||||
# is used as an input to an op (so we know when we're done computing gradients
|
||||
# for that Tensor). We also count, for each tape entry, how many of its output
|
||||
# Tensors need gradients to be computed (Tensors which are not used do not need
|
||||
# any gradients to be computed).
|
||||
#
|
||||
# Finally, we start a backprop stack with a set of tape entries for which we
|
||||
# have all gradients available. This set usually is a subset of the set of
|
||||
# targets (not all since targets which have outputs in the tape will not have
|
||||
# gradients available initially).
|
||||
#
|
||||
# Then we repeatedly pop an entry from the stack, run its backprop, and update
|
||||
# the gradients of its inputs. Once we have computed all gradients for a single
|
||||
# input we can mark this input as done, and this can trigger adding an entry to
|
||||
# the stack if all outputs of that entry are now done.
|
||||
#
|
||||
# When the stack is empty we have gradients for all tensors we're interested in.
|
||||
|
||||
|
||||
def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
|
||||
"""Filters the tape to only include relevant entries and counts tensor usages.
|
||||
|
||||
Args:
|
||||
target: the target to optimize.
|
||||
tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
|
||||
op_to_entry: Map from op id to a tape.TapeEntry object
|
||||
id_sources: the ids of the sources wrt the gradient is being taken.
|
||||
|
||||
Returns:
|
||||
usage counts (how many entries downstream from a tensor use it)
|
||||
op_to_entry_map: entry map (a filtered tape, with only the relevant
|
||||
entries),
|
||||
missing: map from tensor id to how many downstream gradients still need
|
||||
to be computed before this tensor's gradient can be computed.
|
||||
"""
|
||||
if isinstance(target, (ops.Tensor)):
|
||||
tensor_stack = [ops.tensor_id(target)]
|
||||
else:
|
||||
tensor_stack = list([ops.tensor_id(x) for x in target])
|
||||
tensor_usage_counts = {}
|
||||
o_to_e = {} # Copy of just the bits we need from op_to_entry
|
||||
while tensor_stack:
|
||||
t = tensor_stack.pop()
|
||||
op = tensor_to_op.get(t, None)
|
||||
# op is None if the tensor is a source (i.e. was watched directly)
|
||||
if op is None or op in o_to_e:
|
||||
continue
|
||||
op_trace = op_to_entry[op]
|
||||
o_to_e[op] = op_trace
|
||||
for it in op_trace.input_ids:
|
||||
if it in tensor_usage_counts:
|
||||
tensor_usage_counts[it] += 1
|
||||
else:
|
||||
tensor_usage_counts[it] = 1
|
||||
if it not in id_sources and it in tensor_to_op:
|
||||
tensor_stack.append(it)
|
||||
op_missing_tensor_counts = collections.defaultdict(int)
|
||||
for t in tensor_usage_counts:
|
||||
if t in tensor_to_op and tensor_to_op[t] is not None:
|
||||
op_missing_tensor_counts[tensor_to_op[t]] += 1
|
||||
return tensor_usage_counts, o_to_e, op_missing_tensor_counts
|
||||
|
||||
|
||||
def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
|
||||
"""Returns the set of tape entries which are available for backprop."""
|
||||
ready_ops = []
|
||||
for op in op_to_entry:
|
||||
if op not in op_missing_tensor:
|
||||
ready_ops.append(op)
|
||||
return ready_ops
|
||||
|
||||
|
||||
def _initial_gradients(target, output_gradients, tensor_usage_counts):
|
||||
"""Computes the initial gradients for each Tensor."""
|
||||
# Initialize the backprop stack
|
||||
gradients = collections.defaultdict(list)
|
||||
if isinstance(target, ops.Tensor):
|
||||
if output_gradients is not None:
|
||||
output_gradient = output_gradients
|
||||
else:
|
||||
output_gradient = array_ops.ones_like(target)
|
||||
gradients[ops.tensor_id(target)].append(output_gradient)
|
||||
else:
|
||||
for i, t in enumerate(target):
|
||||
if ops.tensor_id(t) in tensor_usage_counts:
|
||||
# Can't provide a gradient of something we're trying to differentiate
|
||||
assert output_gradients is None or output_gradients[i] is None
|
||||
else:
|
||||
if output_gradients is None or output_gradients[i] is None:
|
||||
out_grad = array_ops.ones_like(t)
|
||||
else:
|
||||
out_grad = output_gradients[i]
|
||||
gradients[ops.tensor_id(t)].append(out_grad)
|
||||
return gradients
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def _no_op():
|
||||
yield
|
||||
|
||||
|
||||
def _aggregate_grads(gradients):
|
||||
"""Aggregate gradients from multiple sources.
|
||||
|
||||
Args:
|
||||
gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
|
||||
|
||||
Returns:
|
||||
If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
|
||||
Otherwise returns an aggregated 'IndexedSlices'.
|
||||
"""
|
||||
assert gradients, "No gradients to aggregate"
|
||||
|
||||
if len(gradients) == 1:
|
||||
return gradients[0]
|
||||
if all([isinstance(g, ops.Tensor) for g in gradients]):
|
||||
return math_ops.add_n(gradients)
|
||||
else:
|
||||
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
|
||||
for g in gradients])
|
||||
indexed_slices_list = []
|
||||
for grad in gradients:
|
||||
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
|
||||
if isinstance(grad, ops.Tensor):
|
||||
indexed_slices = ops.IndexedSlices(
|
||||
grad,
|
||||
constant_op.constant(range(grad.shape[0])),
|
||||
constant_op.constant(grad.shape.as_list()))
|
||||
indexed_slices_list.append(indexed_slices)
|
||||
else:
|
||||
indexed_slices_list.append(grad)
|
||||
|
||||
# Dense shapes from all gradients should be the same.
|
||||
dense_shape = indexed_slices_list[0].dense_shape
|
||||
# For simplicity now, always cast to int64.
|
||||
indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
|
||||
for x in indexed_slices_list], 0)
|
||||
values = array_ops.concat([x.values for x in indexed_slices_list], 0)
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
|
||||
|
||||
def _add_new_grads(gradients, gradients_size, tid, grad):
|
||||
"""Adds a new gradient and maybe aggregate the gradients.
|
||||
|
||||
Args:
|
||||
gradients: A dict map from tensor id to list of gradients.
|
||||
gradients_size: A dict map from tensor id to its total units. Might
|
||||
not be initialized.
|
||||
tid: Tensor id.
|
||||
grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
|
||||
|
||||
Raises:
|
||||
ValueError: if `grad` is neight Tensor nor IndexedSlices.
|
||||
"""
|
||||
tensor_grads = gradients[tid]
|
||||
tensor_grads.append(grad)
|
||||
if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
|
||||
return
|
||||
elif tid not in gradients_size:
|
||||
if isinstance(grad, ops.Tensor):
|
||||
size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
|
||||
elif isinstance(grad, ops.IndexedSlices):
|
||||
size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
|
||||
else:
|
||||
raise ValueError("Unexpected gradient type: %s" % type(grad))
|
||||
gradients_size[tid] = size
|
||||
else:
|
||||
size = gradients_size[tid]
|
||||
|
||||
# For simplicity, assume each element to be 4 bytes now.
|
||||
if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
|
||||
gradients[tid] = [_aggregate_grads(tensor_grads)]
|
||||
|
||||
|
||||
def imperative_grad(
|
||||
target,
|
||||
sources,
|
||||
output_gradients=None):
|
||||
"""Computes gradients from the imperatively defined tape on top of the stack.
|
||||
|
||||
Works by filtering the tape, computing how many downstream usages are of each
|
||||
tensor and entry, and repeatedly applying backward functions until we have
|
||||
gradients for all sources.
|
||||
|
||||
Args:
|
||||
target: either a Tensor or list of Tensors to be differentiated.
|
||||
sources: list of Tensors for which we want gradients
|
||||
output_gradients: if not None, a list of gradient provided for each Target,
|
||||
or None if we are to use the target's computed downstream gradient.
|
||||
|
||||
Returns:
|
||||
the gradient wrt each of the sources.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if something goes wrong.
|
||||
ValueError: if there is no sequence of differentiable operations connecting
|
||||
a source and any target Tensor. This can happen either if the target is
|
||||
not computed based on the source, if the tracing was set up incorrectly,
|
||||
or if only non-differentiable functions of the source were used in the
|
||||
computation of target.
|
||||
"""
|
||||
if not tape._tape_stack.stack: # pylint: disable=protected-access
|
||||
raise RuntimeError("Computing a gradient with no tape present")
|
||||
bp_tape = tape.pop_tape()
|
||||
tensor_to_op, op_to_entry = bp_tape.export()
|
||||
# This overwrites the op_to_entry variable, which will release all memory used
|
||||
# to keep traces that are irrelevant to the gradient computation we're doing
|
||||
# here.
|
||||
id_sources = [ops.tensor_id(t) for t in sources]
|
||||
tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
|
||||
target, tensor_to_op, op_to_entry, id_sources)
|
||||
ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
|
||||
gradients = _initial_gradients(target, output_gradients,
|
||||
tensor_usage_counts)
|
||||
gradients_size = dict()
|
||||
# Now exhaust the backprop stack
|
||||
while ready_ops:
|
||||
op = ready_ops.pop()
|
||||
op_trace = op_to_entry.pop(op)
|
||||
out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
|
||||
for i in range(len(out_gradients)):
|
||||
if out_gradients[i] is None:
|
||||
# TODO(apassos) this should be in the right device
|
||||
none_indices = _grad_fn_accepts_none_for_indices.get(
|
||||
op_trace.op_type, None)
|
||||
if none_indices is None or i not in none_indices:
|
||||
out_gradients[i] = array_ops.zeros(
|
||||
*op_trace.output_shape_and_dtype[i])
|
||||
else:
|
||||
out_gradients[i] = _aggregate_grads(out_gradients[i])
|
||||
|
||||
in_gradients = op_trace.backward_function(
|
||||
*(out_gradients + op_trace.side_outputs))
|
||||
in_gradients = ([in_gradients]
|
||||
if isinstance(in_gradients, (ops.Tensor,
|
||||
ops.IndexedSlices,
|
||||
type(None)))
|
||||
else in_gradients)
|
||||
for i, t in enumerate(op_trace.input_ids):
|
||||
if in_gradients[i] is not None:
|
||||
_add_new_grads(gradients, gradients_size, t, in_gradients[i])
|
||||
if tensor_usage_counts.get(t, 0) > 0:
|
||||
tensor_usage_counts[t] -= 1
|
||||
if (t in tensor_to_op
|
||||
and tensor_usage_counts[t] == 0
|
||||
and t not in id_sources):
|
||||
in_op = tensor_to_op[t]
|
||||
if in_op is None:
|
||||
continue
|
||||
if op_missing_tensor.get(in_op, 0) > 0:
|
||||
op_missing_tensor[in_op] -= 1
|
||||
if op_missing_tensor.get(in_op, 0) == 0:
|
||||
ready_ops.append(in_op)
|
||||
result = []
|
||||
for i, s in enumerate(sources):
|
||||
g = gradients.get(ops.tensor_id(s), None)
|
||||
if g is None:
|
||||
result.append(None)
|
||||
else:
|
||||
result.append(_aggregate_grads(g))
|
||||
return result
|
||||
|
||||
_op_attr_type_cache = {}
|
||||
|
||||
|
||||
@ -557,7 +279,7 @@ def _record_gradient(op_name, inputs, attrs, results, name):
|
||||
if _tracing:
|
||||
print("Gradient for", (name if name else op_name), "inputs", op_inputs,
|
||||
"output_grads", orig_outputs, "gradients", result)
|
||||
return result
|
||||
return nest.flatten(result)
|
||||
|
||||
tape.record_operation(op_name, results, inputs, [], grad_fn)
|
||||
if _tracing:
|
||||
@ -615,7 +337,9 @@ def implicit_val_and_grad(f):
|
||||
end_node = f(*args)
|
||||
variables = tape.top_tape_watched_variables()
|
||||
sources = [x.handle for x in variables]
|
||||
grad = imperative_grad(end_node, sources)
|
||||
grad = imperative_grad.imperative_grad(_default_vspace,
|
||||
nest.flatten(end_node),
|
||||
sources)
|
||||
return end_node, list(zip(grad, variables))
|
||||
|
||||
return grad_fn
|
||||
@ -849,6 +573,96 @@ def val_and_grad_function(f, params=None):
|
||||
sources.append(args[i])
|
||||
tape.watch(args[i])
|
||||
result = f(*args)
|
||||
return result, imperative_grad(result, sources, output_gradients=dy)
|
||||
return result, imperative_grad.imperative_grad(
|
||||
_default_vspace, nest.flatten(result), sources,
|
||||
output_gradients=nest.flatten(dy) if dy is not None else None)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def _aggregate_grads(gradients):
|
||||
"""Aggregate gradients from multiple sources.
|
||||
|
||||
Args:
|
||||
gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
|
||||
|
||||
Returns:
|
||||
If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
|
||||
Otherwise returns an aggregated 'IndexedSlices'.
|
||||
"""
|
||||
assert gradients, "No gradients to aggregate"
|
||||
|
||||
if len(gradients) == 1:
|
||||
return gradients[0]
|
||||
if all([isinstance(g, ops.Tensor) for g in gradients]):
|
||||
return math_ops.add_n(gradients)
|
||||
else:
|
||||
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
|
||||
for g in gradients])
|
||||
indexed_slices_list = []
|
||||
for grad in gradients:
|
||||
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
|
||||
if isinstance(grad, ops.Tensor):
|
||||
indexed_slices = ops.IndexedSlices(
|
||||
grad,
|
||||
constant_op.constant(range(grad.shape[0])),
|
||||
constant_op.constant(grad.shape.as_list()))
|
||||
indexed_slices_list.append(indexed_slices)
|
||||
else:
|
||||
indexed_slices_list.append(grad)
|
||||
|
||||
# Dense shapes from all gradients should be the same.
|
||||
dense_shape = indexed_slices_list[0].dense_shape
|
||||
# For simplicity now, always cast to int64.
|
||||
indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
|
||||
for x in indexed_slices_list], 0)
|
||||
values = array_ops.concat([x.values for x in indexed_slices_list], 0)
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
|
||||
|
||||
# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
|
||||
# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
|
||||
# so as to release the gradient tensor to save memory.
|
||||
_MIN_AGGREGATE_COUNT = 4
|
||||
_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
|
||||
|
||||
|
||||
def _add_new_grads(gradients, gradients_size, tid, grad):
|
||||
"""Adds a new gradient and maybe aggregate the gradients.
|
||||
|
||||
Args:
|
||||
gradients: A dict map from tensor id to list of gradients.
|
||||
gradients_size: A dict map from tensor id to its total units. Might
|
||||
not be initialized.
|
||||
tid: Tensor id.
|
||||
grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
|
||||
|
||||
Raises:
|
||||
ValueError: if `grad` is neight Tensor nor IndexedSlices.
|
||||
"""
|
||||
tensor_grads = gradients[tid]
|
||||
tensor_grads.append(grad)
|
||||
if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
|
||||
return
|
||||
elif tid not in gradients_size:
|
||||
if isinstance(grad, ops.Tensor):
|
||||
size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
|
||||
elif isinstance(grad, ops.IndexedSlices):
|
||||
size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
|
||||
else:
|
||||
raise ValueError("Unexpected gradient type: %s" % type(grad))
|
||||
gradients_size[tid] = size
|
||||
else:
|
||||
size = gradients_size[tid]
|
||||
|
||||
# For simplicity, assume each element to be 4 bytes now.
|
||||
if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
|
||||
gradients[tid] = [_aggregate_grads(tensor_grads)]
|
||||
|
||||
|
||||
_default_vspace = imperative_grad.VSpace(
|
||||
add_new_grads_fn=_add_new_grads,
|
||||
aggregate_fn=_aggregate_grads,
|
||||
tensor_id=ops.tensor_id,
|
||||
zeros=array_ops.zeros,
|
||||
ones_like=array_ops.ones_like)
|
||||
|
@ -78,7 +78,7 @@ def custom_gradient(f):
|
||||
# second derivative this way if they capture any output tensors. Change the
|
||||
# signature of custom_gradient.
|
||||
def actual_grad_fn(*outputs):
|
||||
return grad_fn(*outputs)
|
||||
return nest.flatten(grad_fn(*outputs))
|
||||
|
||||
flat_result = nest.flatten(result)
|
||||
tape.record_operation(
|
||||
|
@ -88,7 +88,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
|
||||
else:
|
||||
captured_value = captured_value[1]
|
||||
tape.record_operation("captured_value", [captured_value], [value], [],
|
||||
lambda x: x)
|
||||
lambda x: [x])
|
||||
return captured_value
|
||||
|
||||
|
||||
|
227
tensorflow/python/eager/imperative_grad.py
Normal file
227
tensorflow/python/eager/imperative_grad.py
Normal file
@ -0,0 +1,227 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Code for backpropagation using the tape utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.eager import tape
|
||||
|
||||
|
||||
# Terminology:
|
||||
#
|
||||
# - op: a possibly composite operation, which has an entry in the tape
|
||||
# - target: dy in dx/dy
|
||||
# - source: dx in dx/dy
|
||||
# - tensor: one of the many inputs or outputs of an operation
|
||||
#
|
||||
# Below here we do the gradient algorithm. It works as follows:
|
||||
#
|
||||
# First we filter the tape to just the subset of operations we want to
|
||||
# differentiate. In the process of doing so we count how many times each Tensor
|
||||
# is used as an input to an op (so we know when we're done computing gradients
|
||||
# for that Tensor). We also count, for each tape entry, how many of its output
|
||||
# Tensors need gradients to be computed (Tensors which are not used do not need
|
||||
# any gradients to be computed).
|
||||
#
|
||||
# Finally, we start a backprop stack with a set of tape entries for which we
|
||||
# have all gradients available. This set usually is a subset of the set of
|
||||
# targets (not all since targets which have outputs in the tape will not have
|
||||
# gradients available initially).
|
||||
#
|
||||
# Then we repeatedly pop an entry from the stack, run its backprop, and update
|
||||
# the gradients of its inputs. Once we have computed all gradients for a single
|
||||
# input we can mark this input as done, and this can trigger adding an entry to
|
||||
# the stack if all outputs of that entry are now done.
|
||||
#
|
||||
# When the stack is empty we have gradients for all tensors we're interested in.
|
||||
def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
|
||||
"""Filters the tape to only include relevant entries and counts tensor usages.
|
||||
|
||||
Args:
|
||||
vspace: information about the space we're differentiating in.
|
||||
target: the target to optimize.
|
||||
tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
|
||||
op_to_entry: Map from op id to a tape.TapeEntry object
|
||||
id_sources: the ids of the sources wrt the gradient is being taken.
|
||||
|
||||
Returns:
|
||||
usage counts (how many entries downstream from a tensor use it)
|
||||
op_to_entry_map: entry map (a filtered tape, with only the relevant
|
||||
entries),
|
||||
missing: map from tensor id to how many downstream gradients still need
|
||||
to be computed before this tensor's gradient can be computed.
|
||||
"""
|
||||
tensor_stack = [vspace.tensor_id(x) for x in target]
|
||||
tensor_usage_counts = {}
|
||||
o_to_e = {} # Copy of just the bits we need from op_to_entry
|
||||
while tensor_stack:
|
||||
t = tensor_stack.pop()
|
||||
op = tensor_to_op.get(t, None)
|
||||
# op is None if the tensor is a source (i.e. was watched directly)
|
||||
if op is None or op in o_to_e:
|
||||
continue
|
||||
op_trace = op_to_entry[op]
|
||||
o_to_e[op] = op_trace
|
||||
for it in op_trace.input_ids:
|
||||
if it in tensor_usage_counts:
|
||||
tensor_usage_counts[it] += 1
|
||||
else:
|
||||
tensor_usage_counts[it] = 1
|
||||
if it not in id_sources and it in tensor_to_op:
|
||||
tensor_stack.append(it)
|
||||
op_missing_tensor_counts = collections.defaultdict(int)
|
||||
for t in tensor_usage_counts:
|
||||
if t in tensor_to_op and tensor_to_op[t] is not None:
|
||||
op_missing_tensor_counts[tensor_to_op[t]] += 1
|
||||
return tensor_usage_counts, o_to_e, op_missing_tensor_counts
|
||||
|
||||
|
||||
def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
|
||||
"""Returns the set of tape entries which are available for backprop."""
|
||||
ready_ops = []
|
||||
for op in op_to_entry:
|
||||
if op not in op_missing_tensor:
|
||||
ready_ops.append(op)
|
||||
return ready_ops
|
||||
|
||||
|
||||
def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
|
||||
"""Computes the initial gradients for each Tensor."""
|
||||
# Initialize the backprop stack
|
||||
gradients = collections.defaultdict(list)
|
||||
for i, t in enumerate(target):
|
||||
if vspace.tensor_id(t) in tensor_usage_counts:
|
||||
# Can't provide a gradient of something we're trying to differentiate
|
||||
assert output_gradients is None or output_gradients[i] is None
|
||||
else:
|
||||
if output_gradients is None or output_gradients[i] is None:
|
||||
out_grad = vspace.ones_like(t)
|
||||
else:
|
||||
out_grad = output_gradients[i]
|
||||
gradients[vspace.tensor_id(t)].append(out_grad)
|
||||
return gradients
|
||||
|
||||
|
||||
VSpace = collections.namedtuple(
|
||||
"VSpace",
|
||||
["add_new_grads_fn", "aggregate_fn", "tensor_id", "zeros", "ones_like"])
|
||||
|
||||
|
||||
def imperative_grad(
|
||||
vspace,
|
||||
target,
|
||||
sources,
|
||||
output_gradients=None):
|
||||
"""Computes gradients from the imperatively defined tape on top of the stack.
|
||||
|
||||
Works by filtering the tape, computing how many downstream usages are of each
|
||||
tensor and entry, and repeatedly applying backward functions until we have
|
||||
gradients for all sources.
|
||||
|
||||
Args:
|
||||
vspace: the vector space in which to differentiate.
|
||||
target: either a Tensor or list of Tensors to be differentiated.
|
||||
sources: list of Tensors for which we want gradients
|
||||
output_gradients: if not None, a list of gradient provided for each Target,
|
||||
or None if we are to use the target's computed downstream gradient.
|
||||
|
||||
Returns:
|
||||
the gradient wrt each of the sources.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if something goes wrong.
|
||||
ValueError: if there is no sequence of differentiable operations connecting
|
||||
a source and any target Tensor. This can happen either if the target is
|
||||
not computed based on the source, if the tracing was set up incorrectly,
|
||||
or if only non-differentiable functions of the source were used in the
|
||||
computation of target.
|
||||
"""
|
||||
if not tape._tape_stack.stack: # pylint: disable=protected-access
|
||||
raise RuntimeError("Computing a gradient with no tape present")
|
||||
bp_tape = tape.pop_tape()
|
||||
tensor_to_op, op_to_entry = bp_tape.export()
|
||||
# This overwrites the op_to_entry variable, which will release all memory used
|
||||
# to keep traces that are irrelevant to the gradient computation we're doing
|
||||
# here.
|
||||
id_sources = [vspace.tensor_id(t) for t in sources]
|
||||
tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
|
||||
vspace, target, tensor_to_op, op_to_entry, id_sources)
|
||||
ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
|
||||
gradients = _initial_gradients(vspace, target, output_gradients,
|
||||
tensor_usage_counts)
|
||||
gradients_size = dict()
|
||||
# Now exhaust the backprop stack
|
||||
while ready_ops:
|
||||
op = ready_ops.pop()
|
||||
op_trace = op_to_entry.pop(op)
|
||||
out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
|
||||
for i in range(len(out_gradients)):
|
||||
if out_gradients[i] is None:
|
||||
# TODO(apassos) this should be in the right device
|
||||
none_indices = _grad_fn_accepts_none_for_indices.get(
|
||||
op_trace.op_type, None)
|
||||
if none_indices is None or i not in none_indices:
|
||||
out_gradients[i] = vspace.zeros(
|
||||
*op_trace.output_shape_and_dtype[i])
|
||||
else:
|
||||
out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
|
||||
|
||||
in_gradients = op_trace.backward_function(
|
||||
*(out_gradients + op_trace.side_outputs))
|
||||
for i, t in enumerate(op_trace.input_ids):
|
||||
if in_gradients[i] is not None:
|
||||
vspace.add_new_grads_fn(gradients, gradients_size, t, in_gradients[i])
|
||||
if tensor_usage_counts.get(t, 0) > 0:
|
||||
tensor_usage_counts[t] -= 1
|
||||
if (t in tensor_to_op
|
||||
and tensor_usage_counts[t] == 0
|
||||
and t not in id_sources):
|
||||
in_op = tensor_to_op[t]
|
||||
if in_op is None:
|
||||
continue
|
||||
if op_missing_tensor.get(in_op, 0) > 0:
|
||||
op_missing_tensor[in_op] -= 1
|
||||
if op_missing_tensor.get(in_op, 0) == 0:
|
||||
ready_ops.append(in_op)
|
||||
result = []
|
||||
for i, s in enumerate(sources):
|
||||
g = gradients.get(vspace.tensor_id(s), None)
|
||||
if g is None:
|
||||
result.append(None)
|
||||
else:
|
||||
result.append(vspace.aggregate_fn(g))
|
||||
return result
|
||||
|
||||
|
||||
# TODO(agarwal): use an automatic mechanism for handling None arguments to
|
||||
# gradient functions.
|
||||
# Some gradient functions can accept None arguments for gradients. The following
|
||||
# maps the operation name to the indices at which the corresponding gradient
|
||||
# function can accept None values.
|
||||
# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
|
||||
# during backprop. However the gradient function uses only the first of those
|
||||
# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
|
||||
# indicates that only the gradient corresponding to index 0 is used, and the
|
||||
# gradient values at indices 1-4 are ignored (and hence can be None). The
|
||||
# backprop algorithm can then leverage this by not constructing zeros to
|
||||
# pass for those indices.
|
||||
_grad_fn_accepts_none_for_indices = {
|
||||
"SoftmaxCrossEntropyWithLogits": [1],
|
||||
"FusedBatchNorm": [1, 2, 3, 4]
|
||||
}
|
@ -675,7 +675,7 @@ class _EagerTensorBase(Tensor):
|
||||
if not context.in_graph_mode():
|
||||
self_device = self.device
|
||||
def grad_fun(dresult):
|
||||
return dresult._copy(device_name=self_device)
|
||||
return [dresult._copy(device_name=self_device)]
|
||||
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
|
||||
return new_tensor
|
||||
# pylint: enable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user