Accept output gradients of side outputs when calling functions
Fixes higher-order gradients of function calls.

When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the outputs of the forward function. This sometimes needs a few iterations to converge, but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested.

When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request.

PiperOrigin-RevId: 256268751
This commit is contained in: parent aabcdcbdff, commit bd4feec252
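As an illustration of the behavior this change fixes, here is a minimal sketch mirroring the `testSymbolicHigherOrder` test added below. It is not part of the commit itself and assumes the TF 2.x public API (`tf.function`, `tf.gradients` inside a traced graph):

    import tensorflow as tf

    @tf.function
    def nth_derivative_of_cos(x, order):
      # Calling a nested function produces a function call op. For each extra
      # gradient order, tf.gradients rewrites that op with new side outputs and
      # re-requests the backward function; the fix in this commit stops the
      # cache from ignoring that request.
      y = tf.function(lambda: tf.cos(x))()
      for _ in range(order):  # `order` is a Python int, so this unrolls at trace time.
        y, = tf.gradients(y, [x])
      return y

    x = tf.constant(1.)
    print(nth_derivative_of_cos(x, 2))  # Second derivative of cos: -cos(1.)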
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function

import collections
import enum # pylint: disable=g-bad-import-order
import functools
import itertools
import threading
@@ -40,13 +41,16 @@ from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
@@ -390,7 +394,8 @@ class _EagerDefinedFunction(object):
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
self._control_captures = graph.control_captures
self._func_graph_outputs = outputs
# Shallow copy outputs since ConcreteFunction may mutate it.
self._func_graph_outputs = list(outputs)
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -481,6 +486,13 @@ class _EagerDefinedFunction(object):
return outputs


class _PossibleTapeGradientTypes(enum.Enum):
"""Represents the output of TFE_Py_TapeSetPossibleGradientTypes."""
NONE = 0
FIRST_ORDER = 1
HIGHER_ORDER = 2


class ConcreteFunction(object):
"""Callable object encapsulating a function definition and its gradient.
@@ -517,7 +529,27 @@ class ConcreteFunction(object):
self._inference_function = _EagerDefinedFunction(
_inference_name(self._func_graph.name), self._func_graph,
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
self._backward_graph_function = None

# When graph building without a tape active, symbolic gradients rely on
# regenerating the backward function for higher-order gradients (to account
# for new side outputs of the rewritten forward function call). Thus there
# is no fixed backward function for this case. However, when a tape is
# active (eager or graph building), we generate fixed backward and forward
# functions at forward function call time.
#
# This difference between the tape and non-tape cases is to avoid building
# unneeded backward functions while graph building (where we may or may not
# eventually need gradients).
self._tape_forward_function_first_order = None
self._tape_backward_function_first_order = None
self._tape_forward_function_higher_order = None
self._tape_backward_function_higher_order = None

# A map from the number of forward function outputs with accepted gradients
# to backward functions, used to cache non-tape backward function
# generation.
self._cached_graph_backprop_functions = {}

self._signature = signature
self._gradient_name = None
@@ -673,12 +705,28 @@ class ConcreteFunction(object):
"Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captured_inputs

if (tape.should_record(tensor_inputs) or
tape.should_record(captured_inputs)):
if context.executing_eagerly():
return self._eager_backprop_call(args)
else:
return self._backprop_call_with_delayed_rewrite(args)
possible_gradient_type = _PossibleTapeGradientTypes(
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
# There is a single non-persistent tape active, so the user can only
# request first-order gradients from a tape. We can spend less time graph
# building since we know this.
#
# We may still end up computing higher-order gradients, but that'd be
# through `tf.gradients`, which can re-write the forward pass and so needs
# no preparation here.
forward_function, backward_function = (
self._tape_functions_for_first_order())
return self._tape_backprop_call(args, forward_function, backward_function)
elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
# Either there's a persistent tape watching, or there are multiple nested
# tapes. Either way, the user may request higher-order gradients. We'll
# spend a bit more time and make sure higher-order gradients are correct.
forward_function, backward_function = (
self._tape_functions_for_higher_order())
return self._tape_backprop_call(args, forward_function, backward_function)
# else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
# tape is recording.

# Only need to override the gradient in graph mode and when we have outputs.
if context.executing_eagerly() or not self.outputs:
@@ -708,30 +756,39 @@ class ConcreteFunction(object):

def _grad_fn(self, op, *doutputs):
"""Gradients of this function."""
if self._backward_graph_function is None:
self._construct_backprop_function()
backwards_function = self._graph_backprop_function(len(doutputs))
self._forward_function.add_to_graph(op.graph)

# pylint: disable=protected-access
self._forward_function.add_to_graph(op.graph)
num_inference_outputs = self._inference_function._num_outputs

# Rewrite an inference call op to be a forward call op
if op.get_attr("f").name.encode() == self._inference_function.name:
op._set_func_attr("f", self._forward_function.name)
op._set_type_list_attr("Tout", self._forward_function._output_types)
op._add_outputs(
self._forward_function._output_types[num_inference_outputs:],
self._forward_function._output_shapes[num_inference_outputs:])
for i in range(num_inference_outputs, len(op.outputs)):
func_graph_output = self._forward_function._func_graph_outputs[i]
custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
op._set_func_attr("f", self._forward_function.name)
op._set_type_list_attr("Tout", self._forward_function._output_types)
op._add_outputs(
self._forward_function._output_types[len(op.outputs):],
self._forward_function._output_shapes[len(op.outputs):])
for i in range(len(op.outputs)):
func_graph_output = self._forward_function._func_graph_outputs[i]
custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
# pylint: enable=protected-access

capture_mapping = dict(zip(self._func_graph.outputs, op.outputs))
remapped_captures = []
for capture in backwards_function.captured_inputs:
remapped_captures.append(capture_mapping.get(capture, capture))

# Replace Nones with zeros since we're calling a graph function which
# expects numeric inputs.
cleaned_doutputs = []
for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
if gradients_util.IsTrainable(placeholder):
if doutput is not None:
cleaned_doutputs.append(doutput)
else:
cleaned_doutputs.append(default_gradient.zeros_like(placeholder))

# Compute the gradients using the side outputs
side_outputs = op.outputs[num_inference_outputs:]
args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
return self._backward_graph_function._call_flat( # pylint: disable=protected-access
(a for a in args if a is not None),
self._backward_graph_function.captured_inputs)
return backwards_function._call_flat( # pylint: disable=protected-access
cleaned_doutputs, remapped_captures)

@property
def name(self):
@@ -820,16 +877,190 @@ class ConcreteFunction(object):
# 2. Otherwise, defun will create two functions, one for forward pass,
# and the backward pass will be created via tape.
# When registering the function, we register both cases.
if self._backward_graph_function is None:
self._construct_backprop_function()
backward_function = self._graph_backprop_function()._inference_function
forward_function = self._forward_function
backward_function = self._backward_graph_function._inference_function
# pylint: enable=protected-access
forward_function.add_to_graph(g)
backward_function.add_to_graph(g)

def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
def _graph_backprop_function(self, num_doutputs=None):
"""A possibly-cached backprop function."""
backward_function = self._cached_graph_backprop_functions.get(
num_doutputs, None)
if backward_function is not None:
return backward_function
backward_function = self._construct_graph_backprop_function(num_doutputs)
self._cached_graph_backprop_functions[num_doutputs] = backward_function
return backward_function

def _construct_graph_backprop_function(self, num_doutputs=None):
"""Constructs a backprop function object for this function.

Args:
num_doutputs: The constructed backprop function will take output gradients
for the first `num_doutputs` outputs of the forward function. Defaults
to the number of outputs for the inference function, but when
higher-order gradients are computed this will increase to include side
outputs.

Returns:
A backward function taking `num_doutputs` arguments and returning
gradients with respect to inputs of the forward function.

self._forward_function is re-generated to account for new side outputs, if
any extra were required when building the backward pass.
"""
if num_doutputs is None:
num_doutputs = len(self._inference_function.signature.output_arg)
trainable_outputs = [
output for output in self._func_graph.outputs[:num_doutputs]
if gradients_util.IsTrainable(output)]

signature = []
for t in trainable_outputs:
signature.append(
tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))

def _backprop_function(*grad_ys):
return gradients_util._GradientsHelper( # pylint: disable=protected-access
trainable_outputs,
self._func_graph.inputs,
grad_ys=grad_ys,
src_graph=self._func_graph)

with self._func_graph.as_default():
backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
func_graph_module.func_graph_from_py_func(
name=backwards_graph.name,
python_func=_backprop_function,
args=[], kwargs={},
signature=signature,
func_graph=backwards_graph)
backwards_graph_captures = list(backwards_graph.captures.keys())
captures_from_forward = [
c for c in backwards_graph_captures if
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

forward_function_name = _forward_name(self._func_graph.name)

existing_outputs = set(self._func_graph.outputs)
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
self._func_graph.outputs.append(capture)
backward_function_attr = _parse_func_attrs(
{FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
backward_function_attr.update(self._attrs)

backward_function = ConcreteFunction(
backwards_graph, attrs=backward_function_attr)
forward_function_attr = _parse_func_attrs({
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
backward_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)

self._forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs, forward_function_attr)
return backward_function
def _tape_functions_for_first_order(self):
"""Shortcut for when only first-order gradients are required.

The returned backward function does not accept gradients with respect to
side output of forward_function. This is fine as long as the user can't
possibly request second order tape gradients, as when they've used a single
non-persistent GradientTape. Since we don't need the backward function to
take gradients with respect to side outputs, we can skip some potentially
slow graph building.

Returns:
A tuple of (forward_function, backward_function):
forward_function: Takes the same inputs as the inference function, but
returns side outputs used by backward_function in addition to the
inference function's outputs.
backward_function: Takes side outputs from forward_function and
gradients with respect to the "real" outputs of forward_function and
returns gradients with respect to the inputs.
"""
if self._tape_forward_function_first_order is not None:
return (self._tape_forward_function_first_order,
self._tape_backward_function_first_order)
outputs = self._func_graph.outputs[
:len(self._inference_function.signature.output_arg)]
forward_function, backward_function = (
self._tape_forward_and_backward_functions(outputs))
self._tape_forward_function_first_order = forward_function
self._tape_backward_function_first_order = backward_function
return forward_function, backward_function

# TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
# generalizing if so.
def _tape_functions_for_higher_order(self):
"""Forward and backward functions suitable for higher-order gradients.

Unlike `_tape_functions_for_first_order`, the backward function built by
this method accepts gradients for all of the outputs of the returned forward
function, including side outputs.

Returns:
A tuple of (forward_function, backward_function):
forward_function: Takes the same inputs as the inference function, but
returns side outputs used by backward_function in addition to the
inference function's outputs.
backward_function: Takes side outputs from forward_function and
gradients with respect to all of its outputs, real and side. Returns
gradients with respect to the inputs.
"""
if self._tape_forward_function_higher_order is not None:
return (self._tape_forward_function_higher_order,
self._tape_backward_function_higher_order)
outputs = []
# First we need to figure out how many side outputs from the forward pass
# will be required. We do this in a temporary graph to avoid actually
# running multiple copies of the backward pass (one per _GradientsHelper
# call).
#
# While computing gradients, the backward function captures Tensors from
# the forward function. We add these as side outputs of the original
# function. However, we then need to accept output gradients with respect
# to these side outputs for higher order gradients to work. Thus we loop
# until the number of outputs of the function stabilizes. Note that this
# is only required for tape gradients, where we need to declare in advance
# all of the forward op's outputs: symbolic gradients with tf.gradients
# instead rely on regenerating backward functions when higher-order
# gradients are requested.
while len(outputs) < len(self._func_graph.outputs):
new_outputs = self._func_graph.outputs[len(outputs):]
outputs = list(self._func_graph.outputs)
self._tape_forward_and_backward_functions(new_outputs)
forward_function, backward_function = (
self._tape_forward_and_backward_functions(outputs))
if len(self._func_graph.outputs) != len(outputs):
raise AssertionError(
("Unexpectedly added new outputs to the forward function when "
"building the backward function: {}").format(
self._func_graph.outputs[len(outputs):]))
self._tape_forward_function_higher_order = forward_function
self._tape_backward_function_higher_order = backward_function
return forward_function, backward_function

def _tape_forward_and_backward_functions(self, outputs):
"""Constructs tape forward and back functions for `outputs`."""
# First figure out which of `outputs` are trainable. We'll accept gradients
# for each of these in the backward function.
handles_to_variables = {self._func_graph.captures[v.handle]: v
for v in self._func_graph.variables
if v.handle in self._func_graph.captures}
trainable_outputs = []
for output in outputs:
if gradients_util.IsTrainable(output):
# Swap in the Variable object for resource handles if we can so
# sparse gradients work.
output = handles_to_variables.get(output, output)
trainable_outputs.append(output)

backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
# Keep track of the forward graph so that if the backwards graph
@@ -837,73 +1068,79 @@ class ConcreteFunction(object):
# the forward graph. This is an edge case that can only happen with
# tf.custom_gradient.
backwards_graph._forward_func_graph = self._func_graph # pylint: disable=protected-access
forward_function_name = _forward_name(self._func_graph.name)
outputs = [x for x in self._func_graph.outputs
if gradients_util.IsTrainable(x)]
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in outputs
]
gradients_wrt_outputs = []
for output in trainable_outputs:
gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
output)
gradients_wrt_outputs.append(
graph_placeholder(gradient_dtype, gradient_shape))
gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
outputs,
trainable_outputs,
self._func_graph.inputs,
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)

backwards_graph_captures = list(backwards_graph.captures.keys())
captures_from_forward = [
c for c in backwards_graph.captures.keys() if
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
existing_outputs = set(self._func_graph.outputs)
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
self._func_graph.outputs.append(capture)

forward_function_name = _forward_name(self._func_graph.name)
backward_function_attr = _parse_func_attrs(
{FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
backward_function_attr.update(self._attrs)

# The ordering of `backwards_graph.inputs` is important: inputs of
# `self._backward_graph_function` correspond to outputs of
# `self._forward_function`.
backwards_graph.inputs = gradients_wrt_outputs + list(
backwards_graph.captures.values())
# Clear captures, since we pass them in as inputs.
backwards_graph.captures = {}
# `backward_function` correspond to outputs (including
# side outputs) of `self._tape_forward_function`.
backwards_graph.inputs = (
gradients_wrt_outputs + list(backwards_graph.captures.values()))
backwards_graph.outputs.extend(
grad
for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
self._backward_graph_function = ConcreteFunction(
backward_function = ConcreteFunction(
backwards_graph, attrs=backward_function_attr)

forward_function_attr = _parse_func_attrs({
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
backward_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
self._forward_function = _EagerDefinedFunction(

forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs + backwards_graph_captures,
self._func_graph.outputs,
forward_function_attr)
return forward_function, backward_function

def _eager_backprop_call(self, args):
def _tape_backprop_call(self, args, forward_function, backward_function):
"""Calls the forward function and records the result on a tape.

This method fully constructs the forward and backward functions before
calling the function and recording them on the tape.

(Only records results on a tape if the function has outputs).

Args:
args: All inputs to the function, including resolved captured inputs
forward_function: The forward pass, outputting both user-specified and
side outputs.
backward_function: Computes gradients for inputs of forward_function given
output gradients for the first `N` of forward_function's outputs, not
necessarily all of them. See `_tape_functions_for_first_order` and
`_tape_functions_for_higher_order`.

Returns:
The call output.
"""
if self._backward_graph_function is None:
self._construct_backprop_function()

ctx = context.context()

self._register_gradient()
with ops.get_default_graph().gradient_override_map(
{"PartitionedCall": self._gradient_name,
"StatefulPartitionedCall": self._gradient_name}):
outputs = self._forward_function.call(ctx, args)
outputs = forward_function.call(ctx, args)

if isinstance(outputs, ops.Operation) or outputs is None:
return outputs
@@ -912,19 +1149,55 @@ class ConcreteFunction(object):
# `side_outputs` are the intermediate Tensors that were added as outputs to
# the forward graph function so that we can compute its gradient.
real_outputs = outputs[:self._num_outputs]
skip_positions = [i for i, t in enumerate(real_outputs)
if not gradients_util.IsTrainable(t)]
side_outputs = outputs[self._num_outputs:]

def backward_function(*args):
args = [a for i, a in enumerate(args)
if a is not None and i not in skip_positions]
return self._backward_graph_function._call_flat( # pylint: disable=protected-access
list(args) + side_outputs,
self._backward_graph_function.captured_inputs)
capture_mapping = dict(zip(self._func_graph.outputs, outputs))
remapped_captures = [
capture_mapping.get(capture, capture)
for capture in backward_function.captured_inputs]
# We may need to use zeros_like to get a zero for variant Tensors with
# unconnected gradients. We do that in advance so we don't have to hold on
# to the outputs themselves, which may not be needed otherwise.
variant_zeros_like = {}
backward_function_inputs = (
len(backward_function.inputs) - len(backward_function.captured_inputs))
recorded_outputs = []
trainable_recorded_outputs = 0
skip_positions = []
for output_index, output in enumerate(outputs):
if trainable_recorded_outputs < backward_function_inputs:
recorded_outputs.append(output)
if gradients_util.IsTrainable(output):
trainable_recorded_outputs += 1
else:
skip_positions.append(output_index)
if output.dtype == dtypes.variant:
variant_zeros_like[output_index] = default_gradient.zeros_like(output)

tape.record_operation(self._forward_function.signature.name, real_outputs,
args, backward_function)
def _backward_function_wrapper(*args):
"""Process output gradients and call the backward function."""
processed_args = []
input_index = 0
for output_index, arg in enumerate(args):
if output_index in skip_positions:
continue
if arg is None:
# We're calling a (non-polymorphic) ConcreteFunction, so we need to
# have a Tensor value for each Tensor we thought would be trainable
# based on its dtype, even if it ended up being unconnected.
input_placeholder = backward_function.inputs[
input_index]
if input_placeholder.dtype == dtypes.variant:
arg = variant_zeros_like[output_index]
else:
arg = array_ops.zeros(
*default_gradient.shape_and_dtype(input_placeholder))
processed_args.append(arg)
input_index += 1
return backward_function._call_flat( # pylint: disable=protected-access
processed_args, remapped_captures)

tape.record_operation(forward_function.signature.name,
recorded_outputs, args, _backward_function_wrapper)
return self._build_call_outputs(real_outputs)

def _backprop_call_with_delayed_rewrite(self, args):
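The dispatch above keys off `_PossibleTapeGradientTypes`. As a hedged illustration (not part of the diff, using only the public `tf.GradientTape` API), the two tape cases look like this from the user's side:

    import tensorflow as tf

    @tf.function
    def f(x):
      return tf.cos(x)

    x = tf.constant(1.)

    # FIRST_ORDER: a single non-persistent tape can only ask for first-order
    # tape gradients, so the cheaper forward/backward pair suffices.
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f(x)
    dy = tape.gradient(y, x)  # -sin(1.)

    # HIGHER_ORDER: nested (or persistent) tapes may ask for gradients of
    # gradients, so the backward function must also accept gradients for the
    # forward function's side outputs.
    with tf.GradientTape() as outer:
      outer.watch(x)
      with tf.GradientTape() as inner:
        inner.watch(x)
        y = f(x)
      dy = inner.gradient(y, x)
    d2y = outer.gradient(dy, x)  # -cos(1.)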
@@ -40,6 +40,13 @@ from tensorflow.python.platform import test
from tensorflow.python.util import nest


_COS_DERIVATIVES = [math_ops.cos,
lambda x: -math_ops.sin(x),
lambda x: -math_ops.cos(x),
math_ops.sin,
math_ops.cos]


class FunctionGradientsTest(test.TestCase, parameterized.TestCase):

def testGraphModeWithGradients(self):
@@ -68,6 +75,145 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(grads.eval(), 2.0)
self.assertEqual(grads.shape, v.shape)

def testSymbolicHigherOrder(self):
@def_function.function
def f(x, order):
y = def_function.function(lambda: math_ops.cos(x))()
for _ in range(order):
y, = gradients_impl.gradients(y, [x])
return y
for order, expected in enumerate(_COS_DERIVATIVES):
self.assertAllClose(
expected(constant_op.constant(1.)),
f(constant_op.constant(1.), order))

@parameterized.parameters([dict(persistent=True),
dict(persistent=False)])
def testSymbolicHigherOrderUnderTape(self, persistent):
@def_function.function
def f(x, order):
with backprop.GradientTape(persistent=persistent) as tape:
tape.watch(x)
# Note that having a tape active, even if we don't use it, forces us
# down a different function call path. Symbolic gradients should work
# here too; correctness of tape gradients are tested elsewhere.
y = def_function.function(lambda: math_ops.cos(x))()
tape_dy = tape.gradient(y, x)
for _ in range(order):
y, = gradients_impl.gradients(y, [x])
if order > 0:
y1 = tape_dy
for _ in range(order - 1):
y1, = gradients_impl.gradients(y1, [x])
else:
y1 = y
return y, y1
for order, expected_f in enumerate(_COS_DERIVATIVES):
expected = self.evaluate(expected_f(constant_op.constant(1.)))
self.assertAllClose(
(expected, expected),
f(constant_op.constant(1.), order))

def testIteratedGradientsNested(self):

def _grad(f):
def _grad_function(primal):
with backprop.GradientTape() as tape:
tape.watch(primal)
primal_out = f(primal)
return tape.gradient(primal_out, primal)
return _grad_function

@def_function.function
def _forward(x):
return math_ops.cos(x)

f = _forward
traced_f = def_function.function(f)
one = constant_op.constant(1.)
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(one), f(one))
self.assertAllClose(expected(one), traced_f(one))
self.assertAllClose(expected(one), def_function.function(f)(one))
f = _grad(f)
traced_f = def_function.function(_grad(traced_f))

def testIteratedGradientsNestedWithVariable(self):

def _grad(f):
def _grad_function():
with backprop.GradientTape() as tape:
primal_out = f()
g, = tape.gradient(primal_out, tape.watched_variables())
return g
return _grad_function

v = variables.Variable(2.)

@def_function.function
def _forward():
return math_ops.cos(v)

f = _forward

two = constant_op.constant(2.)

for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(two), f())
self.assertAllClose(expected(two), def_function.function(f)())
f = _grad(f)

def testIteratedGradientsPersistent(self):

@def_function.function
def _forward(z):
return math_ops.cos(z)

f = _forward
with backprop.GradientTape(persistent=True) as tape:
start = constant_op.constant(1.)
tape.watch(start)
x = f(start)
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(start), x)
x = tape.gradient(x, start)

def testHigherOrderWithVariable(self):

v = variables.Variable(1.)

@def_function.function
def _forward():
return math_ops.cos(v)

f = _forward
with backprop.GradientTape(persistent=True) as tape:
x = f()
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(constant_op.constant(1.)), x)
x, = tape.gradient(x, tape.watched_variables())

def testGradientsChained(self):

@def_function.function
def _forward(z):
return math_ops.cos(z)

f = _forward
x = constant_op.constant(1.)
with backprop.GradientTape() as t:
t.watch(x)
y = f(x)
with backprop.GradientTape() as tt:
doutputs = constant_op.constant(2.)
tt.watch(doutputs)
g = t.gradient(y, x, doutputs)
self.assertAllClose(-2. * math_ops.sin(x), g)
gg = tt.gradient(g, doutputs)
# We're taking gradients with respect to doutputs, which is just a linear
# function of the gradient.
self.assertAllClose(-math_ops.sin(x), gg)

def testSymGradGatherNd(self):
with ops.Graph().as_default(), self.cached_session() as sess:
@@ -148,6 +148,13 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
PyObject* TFE_Py_TapeSetIsEmpty();

PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);

// Like TFE_Py_TapeSetShouldRecord but with a ternary return:
// - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecord is false)
// - 1 if first-order gradients may be requested
// - 2 if higher-order gradients may be requested
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors);

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
@@ -695,6 +695,14 @@ void SetOpAttrWithDefaults(
}
}

PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(num);
#else
return PyInt_FromLong(num);
#endif
}

// Python subclass of Exception that is created on not ok Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
@@ -1500,33 +1508,51 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}

PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors,
std::vector<tensorflow::int64>* tensor_ids,
std::vector<tensorflow::DataType>* dtypes) {
tensorflow::Safe_PyObjectPtr seq(
PySequence_Fast(tensors, "expected a sequence"));
if (seq == nullptr) {
return false;
}
int len = PySequence_Fast_GET_SIZE(seq.get());
tensor_ids->reserve(len);
dtypes->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
tensor_ids->push_back(FastTensorId(item));
dtypes->push_back(FastTensorDtype(item));
}
return true;
}

bool TapeCouldPossiblyRecord(PyObject* tensors) {
if (tensors == Py_None) {
Py_RETURN_FALSE;
return false;
}
if (*ThreadTapeIsStopped()) {
Py_RETURN_FALSE;
return false;
}
if (!HasTape()) {
return false;
}
return true;
}

PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
Py_RETURN_FALSE;
}
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
return nullptr;
}
int len = PySequence_Fast_GET_SIZE(seq);
// TODO(apassos) consider not building a list and changing the API to check
// each tensor individually.
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
tensor_ids.reserve(len);
dtypes.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
tensor_ids.push_back(FastTensorId(item));
dtypes.push_back(FastTensorDtype(item));
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
Py_DECREF(seq);
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
@@ -1543,6 +1569,53 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}

PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
return GetPythonObjectFromInt(0);
}
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}

// If there is a persistent tape watching, or if there are multiple tapes
// watching, we'll return immediately indicating that higher-order tape
// gradients are possible.
bool some_tape_watching = false;
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
if (tape->tape->IsPersistent() || some_tape_watching) {
// Either this is the second tape watching, or this tape is persistent:
// higher-order gradients are possible.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
if (some_tape_watching) {
// This is the second tape watching: higher-order gradients are
// possible. Note that there's no equivalent of persistence for
// forward-mode.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
if (some_tape_watching) {
// There's exactly one non-persistent tape. The user can request first-order
// gradients but won't be able to get higher-order tape gradients.
return GetPythonObjectFromInt(1);
} else {
// There are no tapes. The user can't request tape gradients.
return GetPythonObjectFromInt(0);
}
}

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (*ThreadTapeIsStopped()) {
return;
@@ -1997,14 +2070,6 @@ PyObject* GetPythonObjectFromString(const char* s) {
#endif
}

PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(num);
#else
return PyInt_FromLong(num);
#endif
}

bool CheckResourceVariable(PyObject* item) {
if (PyObject_TypeCheck(item, resource_variable_type)) {
tensorflow::Safe_PyObjectPtr handle(
@@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@@ -33,3 +35,25 @@ def get_zeros_dtype(t):
else:
return handle_data.shape_and_type[0].dtype
return t.dtype


def shape_and_dtype(t):
"""Return the shape and dtype for the default gradient for a Tensor."""
if t.dtype == dtypes.resource:
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if (handle_data is None or not handle_data.is_set or
len(handle_data.shape_and_type) != 1):
return tensor_shape.TensorShape(None), dtypes.float32
else:
shape_and_type = handle_data.shape_and_type[0]
return (tensor_shape.TensorShape(shape_and_type.shape),
dtypes.as_dtype(shape_and_type.dtype))
return t.shape, t.dtype


def zeros_like(t):
"""Like array_ops.zeros_like, but respects resource handles."""
if t.dtype == dtypes.resource:
return array_ops.zeros(*shape_and_dtype(t))
else:
return array_ops.zeros_like(t)
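A hypothetical usage sketch of the `shape_and_dtype` and `zeros_like` helpers above (not part of the diff; it imports the internal module path shown in this change and assumes eager execution):

    import tensorflow as tf
    from tensorflow.python.ops import default_gradient

    # For a plain float Tensor the default gradient matches its own shape/dtype.
    t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    shape, dtype = default_gradient.shape_and_dtype(t)  # ((2, 2), tf.float32)
    zeros = default_gradient.zeros_like(t)              # same as tf.zeros_like(t)

    # For a variable's resource handle, the gradient shape/dtype is recovered
    # from the handle data rather than the handle itself (dtype resource).
    v = tf.Variable([1.0, 2.0])
    h_shape, h_dtype = default_gradient.shape_and_dtype(v.handle)  # ((2,), tf.float32)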
@@ -74,6 +74,7 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetIsStopped;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeGradient;
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
@@ -174,6 +175,18 @@ class Loader(object):
for bound_input, internal_capture in zip(
bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
concrete_function.graph.captures[bound_input] = internal_capture
if internal_capture.dtype == dtypes.resource:
if resource_variable_ops.is_resource_variable(bound_input):
try:
handle = bound_input.handle
except ValueError:
# For mirrored variables we'll copy handle data for components
# as they get captured.
pass
else:
custom_gradient.copy_handle_data(handle, internal_capture)
else:
custom_gradient.copy_handle_data(bound_input, internal_capture)
# Setting "captures" first means "capture" won't create a new
# placeholder for this input.
concrete_function.graph.capture(bound_input)
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import weakref

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
@@ -172,14 +173,21 @@ class CapturableResourceDeleter(object):

def __init__(self, destroy_resource_fn=None):
if destroy_resource_fn:
self.destroy_resource = destroy_resource_fn
self._destroy_resource = destroy_resource_fn
self._destruction_context = (
context.eager_mode if context.executing_eagerly()
else ops.get_default_graph().as_default)
else:
self._destroy_resource = None

def destroy_resource(self):
"""A function that destroys the resource."""
pass
if self._destroy_resource:
return self._destroy_resource()

def __del__(self):
self.destroy_resource()
if self._destroy_resource:
with self._destruction_context():
self._destroy_resource()


class CapturableResource(base.Trackable):