Accept output gradients of side outputs when calling functions

Fixes higher-order gradients of function calls

When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the outputs of the forward function. Building this pair sometimes takes a few iterations to converge, since each backward pass may capture extra forward-pass tensors that then become new side outputs, but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested.
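For illustration, a minimal sketch of the tape case in terms of the public tf.function and tf.GradientTape APIs (the values follow the cos/sin derivatives exercised by the new tests; the snippet is illustrative and not part of the change itself):

# Sketch: higher-order tape gradients through a function call, assuming
# TF 2.x-style eager execution.
import tensorflow as tf

@tf.function
def f(x):
  return tf.cos(x)

x = tf.constant(1.)
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y = f(x)                    # forward call builds a fixed forward/backward pair
  dy = tape.gradient(y, x)    # first order: -sin(x)
  d2y = tape.gradient(dy, x)  # higher order reuses the same pair: -cos(x)
print(dy.numpy(), d2y.numpy())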

When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request.
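A matching sketch of the symbolic case, using tf.gradients inside a tf.function so that a graph is being built (it mirrors the testSymbolicHigherOrder test added in this change):

# Sketch: iterated tf.gradients through a function call op.
import tensorflow as tf

@tf.function
def second_derivative(x):
  y = tf.function(lambda: tf.cos(x))()  # a function call op in the outer graph
  dy, = tf.gradients(y, [x])            # rewrites the call op with side outputs
  d2y, = tf.gradients(dy, [x])          # re-requests the backward function
  return d2y

print(second_derivative(tf.constant(1.)).numpy())  # roughly -cos(1.)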

PiperOrigin-RevId: 256268751
Allen Lavoie 2019-07-02 17:02:39 -07:00 committed by TensorFlower Gardener
parent aabcdcbdff
commit bd4feec252
8 changed files with 635 additions and 98 deletions

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
import enum # pylint: disable=g-bad-import-order
import functools
import itertools
import threading
@ -40,13 +41,16 @@ from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
@ -390,7 +394,8 @@ class _EagerDefinedFunction(object):
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
self._control_captures = graph.control_captures
- self._func_graph_outputs = outputs
# Shallow copy outputs since ConcreteFunction may mutate it.
self._func_graph_outputs = list(outputs)
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@ -481,6 +486,13 @@ class _EagerDefinedFunction(object):
return outputs
class _PossibleTapeGradientTypes(enum.Enum):
"""Represents the output of TFE_Py_TapeSetPossibleGradientTypes."""
NONE = 0
FIRST_ORDER = 1
HIGHER_ORDER = 2
class ConcreteFunction(object):
"""Callable object encapsulating a function definition and its gradient.
@ -517,7 +529,27 @@ class ConcreteFunction(object):
self._inference_function = _EagerDefinedFunction(
_inference_name(self._func_graph.name), self._func_graph,
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
- self._backward_graph_function = None
# When graph building without a tape active, symbolic gradients rely on
# regenerating the backward function for higher-order gradients (to account
# for new side outputs of the rewritten forward function call). Thus there
# is no fixed backward function for this case. However, when a tape is
# active (eager or graph building), we generate fixed backward and forward
# functions at forward function call time.
#
# This difference between the tape and non-tape cases is to avoid building
# unneeded backward functions while graph building (where we may or may not
# eventually need gradients).
self._tape_forward_function_first_order = None
self._tape_backward_function_first_order = None
self._tape_forward_function_higher_order = None
self._tape_backward_function_higher_order = None
# A map from the number of forward function outputs with accepted gradients
# to backward functions, used to cache non-tape backward function
# generation.
self._cached_graph_backprop_functions = {}
self._signature = signature
self._gradient_name = None
@ -673,12 +705,28 @@ class ConcreteFunction(object):
"Tensor." % (self._func_graph.name, i, str(arg))) "Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captured_inputs
- if (tape.should_record(tensor_inputs) or
- tape.should_record(captured_inputs)):
- if context.executing_eagerly():
- return self._eager_backprop_call(args)
- else:
- return self._backprop_call_with_delayed_rewrite(args)
possible_gradient_type = _PossibleTapeGradientTypes(
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
# There is a single non-persistent tape active, so the user can only
# request first-order gradients from a tape. We can spend less time graph
# building since we know this.
#
# We may still end up computing higher-order gradients, but that'd be
# through `tf.gradients`, which can re-write the forward pass and so needs
# no preparation here.
forward_function, backward_function = (
self._tape_functions_for_first_order())
return self._tape_backprop_call(args, forward_function, backward_function)
elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
# Either there's a persistent tape watching, or there are multiple nested
# tapes. Either way, the user may request higher-order gradients. We'll
# spend a bit more time and make sure higher-order gradients are correct.
forward_function, backward_function = (
self._tape_functions_for_higher_order())
return self._tape_backprop_call(args, forward_function, backward_function)
# else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
# tape is recording.
# Only need to override the gradient in graph mode and when we have outputs.
if context.executing_eagerly() or not self.outputs:
@ -708,30 +756,39 @@ class ConcreteFunction(object):
def _grad_fn(self, op, *doutputs):
"""Gradients of this function."""
- if self._backward_graph_function is None:
- self._construct_backprop_function()
backwards_function = self._graph_backprop_function(len(doutputs))
self._forward_function.add_to_graph(op.graph)
# pylint: disable=protected-access
- self._forward_function.add_to_graph(op.graph)
- num_inference_outputs = self._inference_function._num_outputs
# Rewrite an inference call op to be a forward call op
if op.get_attr("f").name.encode() == self._inference_function.name:
op._set_func_attr("f", self._forward_function.name)
op._set_type_list_attr("Tout", self._forward_function._output_types)
op._add_outputs(
- self._forward_function._output_types[num_inference_outputs:],
- self._forward_function._output_shapes[num_inference_outputs:])
- for i in range(num_inference_outputs, len(op.outputs)):
self._forward_function._output_types[len(op.outputs):],
self._forward_function._output_shapes[len(op.outputs):])
for i in range(len(op.outputs)):
func_graph_output = self._forward_function._func_graph_outputs[i]
custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
# pylint: enable=protected-access
capture_mapping = dict(zip(self._func_graph.outputs, op.outputs))
remapped_captures = []
for capture in backwards_function.captured_inputs:
remapped_captures.append(capture_mapping.get(capture, capture))
# Replace Nones with zeros since we're calling a graph function which
# expects numeric inputs.
cleaned_doutputs = []
for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
if gradients_util.IsTrainable(placeholder):
if doutput is not None:
cleaned_doutputs.append(doutput)
else:
cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
# Compute the gradients using the side outputs
- side_outputs = op.outputs[num_inference_outputs:]
- args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
- return self._backward_graph_function._call_flat( # pylint: disable=protected-access
- (a for a in args if a is not None),
- self._backward_graph_function.captured_inputs)
return backwards_function._call_flat( # pylint: disable=protected-access
cleaned_doutputs, remapped_captures)
@property
def name(self):
@ -820,16 +877,190 @@ class ConcreteFunction(object):
# 2. Otherwise, defun will create two functions, one for forward pass,
# and the backward pass will be created via tape.
# When registering the function, we register both cases.
- if self._backward_graph_function is None:
- self._construct_backprop_function()
backward_function = self._graph_backprop_function()._inference_function
forward_function = self._forward_function
- backward_function = self._backward_graph_function._inference_function
# pylint: enable=protected-access
forward_function.add_to_graph(g)
backward_function.add_to_graph(g)
- def _construct_backprop_function(self):
- """Constructs the backprop function object for this function."""
def _graph_backprop_function(self, num_doutputs=None):
"""A possibly-cached backprop function."""
backward_function = self._cached_graph_backprop_functions.get(
num_doutputs, None)
if backward_function is not None:
return backward_function
backward_function = self._construct_graph_backprop_function(num_doutputs)
self._cached_graph_backprop_functions[num_doutputs] = backward_function
return backward_function
def _construct_graph_backprop_function(self, num_doutputs=None):
"""Constructs a backprop function object for this function.
Args:
num_doutputs: The constructed backprop function will take output gradients
for the first `num_doutputs` outputs of the forward function. Defaults
to the number of outputs for the inference function, but when
higher-order gradients are computed this will increase to include side
outputs.
Returns:
A backward function taking `num_doutputs` arguments and returning
gradients with respect to inputs of the forward function.
self._forward_function is re-generated to account for new side outputs, if
any extra were required when building the backward pass.
"""
if num_doutputs is None:
num_doutputs = len(self._inference_function.signature.output_arg)
trainable_outputs = [
output for output in self._func_graph.outputs[:num_doutputs]
if gradients_util.IsTrainable(output)]
signature = []
for t in trainable_outputs:
signature.append(
tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
def _backprop_function(*grad_ys):
return gradients_util._GradientsHelper( # pylint: disable=protected-access
trainable_outputs,
self._func_graph.inputs,
grad_ys=grad_ys,
src_graph=self._func_graph)
with self._func_graph.as_default():
backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
func_graph_module.func_graph_from_py_func(
name=backwards_graph.name,
python_func=_backprop_function,
args=[], kwargs={},
signature=signature,
func_graph=backwards_graph)
backwards_graph_captures = list(backwards_graph.captures.keys())
captures_from_forward = [
c for c in backwards_graph_captures if
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
forward_function_name = _forward_name(self._func_graph.name)
existing_outputs = set(self._func_graph.outputs)
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
self._func_graph.outputs.append(capture)
backward_function_attr = _parse_func_attrs(
{FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
backward_function_attr.update(self._attrs)
backward_function = ConcreteFunction(
backwards_graph, attrs=backward_function_attr)
forward_function_attr = _parse_func_attrs({
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
backward_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
self._forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs, forward_function_attr)
return backward_function
def _tape_functions_for_first_order(self):
"""Shortcut for when only first-order gradients are required.
The returned backward function does not accept gradients with respect to
side output of forward_function. This is fine as long as the user can't
possibly request second order tape gradients, as when they've used a single
non-persistent GradientTape. Since we don't need the backward function to
take gradients with respect to side outputs, we can skip some potentially
slow graph building.
Returns:
A tuple of (forward_function, backward_function):
forward_function: Takes the same inputs as the inference function, but
returns side outputs used by backward_function in addition to the
inference function's outputs.
backward_function: Takes side outputs from forward_function and
gradients with respect to the "real" outputs of forward_function and
returns gradients with respect to the inputs.
"""
if self._tape_forward_function_first_order is not None:
return (self._tape_forward_function_first_order,
self._tape_backward_function_first_order)
outputs = self._func_graph.outputs[
:len(self._inference_function.signature.output_arg)]
forward_function, backward_function = (
self._tape_forward_and_backward_functions(outputs))
self._tape_forward_function_first_order = forward_function
self._tape_backward_function_first_order = backward_function
return forward_function, backward_function
# TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
# generalizing if so.
def _tape_functions_for_higher_order(self):
"""Forward and backward functions suitable for higher-order gradients.
Unlike `_tape_functions_for_first_order`, the backward function built by
this method accepts gradients for all of the outputs of the returned forward
function, including side outputs.
Returns:
A tuple of (forward_function, backward_function):
forward_function: Takes the same inputs as the inference function, but
returns side outputs used by backward_function in addition to the
inference function's outputs.
backward_function: Takes side outputs from forward_function and
gradients with respect to all of its outputs, real and side. Returns
gradients with respect to the inputs.
"""
if self._tape_forward_function_higher_order is not None:
return (self._tape_forward_function_higher_order,
self._tape_backward_function_higher_order)
outputs = []
# First we need to figure out how many side outputs from the forward pass
# will be required. We do this in a temporary graph to avoid actually
# running multiple copies of the backward pass (one per _GradientsHelper
# call).
#
# While computing gradients, the backward function captures Tensors from
# the forward function. We add these as side outputs of the original
# function. However, we then need to accept output gradients with respect
# to these side outputs for higher order gradients to work. Thus we loop
# until the number of outputs of the function stabilizes. Note that this
# is only required for tape gradients, where we need to declare in advance
# all of the forward op's outputs: symbolic gradients with tf.gradients
# instead rely on regenerating backward functions when higher-order
# gradients are requested.
while len(outputs) < len(self._func_graph.outputs):
new_outputs = self._func_graph.outputs[len(outputs):]
outputs = list(self._func_graph.outputs)
self._tape_forward_and_backward_functions(new_outputs)
forward_function, backward_function = (
self._tape_forward_and_backward_functions(outputs))
if len(self._func_graph.outputs) != len(outputs):
raise AssertionError(
("Unexpectedly added new outputs to the forward function when "
"building the backward function: {}").format(
self._func_graph.outputs[len(outputs):]))
self._tape_forward_function_higher_order = forward_function
self._tape_backward_function_higher_order = backward_function
return forward_function, backward_function
def _tape_forward_and_backward_functions(self, outputs):
"""Constructs tape forward and back functions for `outputs`."""
# First figure out which of `outputs` are trainable. We'll accept gradients
# for each of these in the backward function.
handles_to_variables = {self._func_graph.captures[v.handle]: v
for v in self._func_graph.variables
if v.handle in self._func_graph.captures}
trainable_outputs = []
for output in outputs:
if gradients_util.IsTrainable(output):
# Swap in the Variable object for resource handles if we can so
# sparse gradients work.
output = handles_to_variables.get(output, output)
trainable_outputs.append(output)
backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
# Keep track of the forward graph so that if the backwards graph
@ -837,73 +1068,79 @@ class ConcreteFunction(object):
# the forward graph. This is an edge case that can only happen with
# tf.custom_gradient.
backwards_graph._forward_func_graph = self._func_graph # pylint: disable=protected-access
- forward_function_name = _forward_name(self._func_graph.name)
- outputs = [x for x in self._func_graph.outputs
- if gradients_util.IsTrainable(x)]
with backwards_graph.as_default():
- gradients_wrt_outputs = [
- graph_placeholder(x.dtype, x.shape) for x in outputs
- ]
gradients_wrt_outputs = []
for output in trainable_outputs:
gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
output)
gradients_wrt_outputs.append(
graph_placeholder(gradient_dtype, gradient_shape))
gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
- outputs,
trainable_outputs,
self._func_graph.inputs,
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)
- backwards_graph_captures = list(backwards_graph.captures.keys())
captures_from_forward = [
c for c in backwards_graph.captures.keys() if
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
existing_outputs = set(self._func_graph.outputs)
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
self._func_graph.outputs.append(capture)
forward_function_name = _forward_name(self._func_graph.name)
backward_function_attr = _parse_func_attrs(
{FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
backward_function_attr.update(self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
- # `self._backward_graph_function` correspond to outputs of
- # `self._forward_function`.
# `backward_function` correspond to outputs (including
# side outputs) of `self._tape_forward_function`.
- backwards_graph.inputs = gradients_wrt_outputs + list(
- backwards_graph.captures.values())
- # Clear captures, since we pass them in as inputs.
- backwards_graph.captures = {}
backwards_graph.inputs = (
gradients_wrt_outputs + list(backwards_graph.captures.values()))
backwards_graph.outputs.extend(
grad
for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
- self._backward_graph_function = ConcreteFunction(
backward_function = ConcreteFunction(
backwards_graph, attrs=backward_function_attr)
forward_function_attr = _parse_func_attrs({
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
- self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
backward_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
- self._forward_function = _EagerDefinedFunction(
forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
- self._func_graph.outputs + backwards_graph_captures,
self._func_graph.outputs,
forward_function_attr)
return forward_function, backward_function
- def _eager_backprop_call(self, args):
def _tape_backprop_call(self, args, forward_function, backward_function):
"""Calls the forward function and records the result on a tape.
- This method fully constructs the forward and backward functions before
- calling the function and recording them on the tape.
(Only records results on a tape if the function has outputs).
Args:
args: All inputs to the function, including resolved captured inputs
forward_function: The forward pass, outputting both user-specified and
side outputs.
backward_function: Computes gradients for inputs of forward_function given
output gradients for the first `N` of forward_function's outputs, not
necessarily all of them. See `_tape_functions_for_first_order` and
`_tape_functions_for_higher_order`.
Returns:
The call output.
"""
- if self._backward_graph_function is None:
- self._construct_backprop_function()
ctx = context.context()
self._register_gradient()
with ops.get_default_graph().gradient_override_map(
{"PartitionedCall": self._gradient_name,
"StatefulPartitionedCall": self._gradient_name}):
- outputs = self._forward_function.call(ctx, args)
outputs = forward_function.call(ctx, args)
if isinstance(outputs, ops.Operation) or outputs is None:
return outputs
@ -912,19 +1149,55 @@ class ConcreteFunction(object):
# `side_outputs` are the intermediate Tensors that were added as outputs to
# the forward graph function so that we can compute its gradient.
real_outputs = outputs[:self._num_outputs]
- skip_positions = [i for i, t in enumerate(real_outputs)
- if not gradients_util.IsTrainable(t)]
- side_outputs = outputs[self._num_outputs:]
- def backward_function(*args):
- args = [a for i, a in enumerate(args)
- if a is not None and i not in skip_positions]
- return self._backward_graph_function._call_flat( # pylint: disable=protected-access
- list(args) + side_outputs,
- self._backward_graph_function.captured_inputs)
capture_mapping = dict(zip(self._func_graph.outputs, outputs))
remapped_captures = [
capture_mapping.get(capture, capture)
for capture in backward_function.captured_inputs]
# We may need to use zeros_like to get a zero for variant Tensors with
# unconnected gradients. We do that in advance so we don't have to hold on
# to the outputs themselves, which may not be needed otherwise.
variant_zeros_like = {}
backward_function_inputs = (
len(backward_function.inputs) - len(backward_function.captured_inputs))
recorded_outputs = []
trainable_recorded_outputs = 0
skip_positions = []
for output_index, output in enumerate(outputs):
if trainable_recorded_outputs < backward_function_inputs:
recorded_outputs.append(output)
if gradients_util.IsTrainable(output):
trainable_recorded_outputs += 1
else:
skip_positions.append(output_index)
if output.dtype == dtypes.variant:
variant_zeros_like[output_index] = default_gradient.zeros_like(output)
- tape.record_operation(self._forward_function.signature.name, real_outputs,
- args, backward_function)
def _backward_function_wrapper(*args):
"""Process output gradients and call the backward function."""
processed_args = []
input_index = 0
for output_index, arg in enumerate(args):
if output_index in skip_positions:
continue
if arg is None:
# We're calling a (non-polymorphic) ConcreteFunction, so we need to
# have a Tensor value for each Tensor we thought would be trainable
# based on its dtype, even if it ended up being unconnected.
input_placeholder = backward_function.inputs[
input_index]
if input_placeholder.dtype == dtypes.variant:
arg = variant_zeros_like[output_index]
else:
arg = array_ops.zeros(
*default_gradient.shape_and_dtype(input_placeholder))
processed_args.append(arg)
input_index += 1
return backward_function._call_flat( # pylint: disable=protected-access
processed_args, remapped_captures)
tape.record_operation(forward_function.signature.name,
recorded_outputs, args, _backward_function_wrapper)
return self._build_call_outputs(real_outputs)
def _backprop_call_with_delayed_rewrite(self, args):
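For reference, a public-API sketch of the nested-tape case that the HIGHER_ORDER branch above has to handle (illustrative only; it parallels testIteratedGradientsNested in the test changes below):

# Sketch: nested non-persistent tapes can request higher-order gradients,
# so the forward function must expose its side outputs up front.
import tensorflow as tf

@tf.function
def f(x):
  return tf.cos(x)

x = tf.constant(1.)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = f(x)
  dy = inner.gradient(y, x)  # -sin(x), itself recorded by the outer tape
d2y = outer.gradient(dy, x)  # -cos(x)
print(d2y.numpy())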

View File

@ -40,6 +40,13 @@ from tensorflow.python.platform import test
from tensorflow.python.util import nest
_COS_DERIVATIVES = [math_ops.cos,
lambda x: -math_ops.sin(x),
lambda x: -math_ops.cos(x),
math_ops.sin,
math_ops.cos]
class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
def testGraphModeWithGradients(self):
@ -68,6 +75,145 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(grads.eval(), 2.0)
self.assertEqual(grads.shape, v.shape)
def testSymbolicHigherOrder(self):
@def_function.function
def f(x, order):
y = def_function.function(lambda: math_ops.cos(x))()
for _ in range(order):
y, = gradients_impl.gradients(y, [x])
return y
for order, expected in enumerate(_COS_DERIVATIVES):
self.assertAllClose(
expected(constant_op.constant(1.)),
f(constant_op.constant(1.), order))
@parameterized.parameters([dict(persistent=True),
dict(persistent=False)])
def testSymbolicHigherOrderUnderTape(self, persistent):
@def_function.function
def f(x, order):
with backprop.GradientTape(persistent=persistent) as tape:
tape.watch(x)
# Note that having a tape active, even if we don't use it, forces us
# down a different function call path. Symbolic gradients should work
# here too; correctness of tape gradients are tested elsewhere.
y = def_function.function(lambda: math_ops.cos(x))()
tape_dy = tape.gradient(y, x)
for _ in range(order):
y, = gradients_impl.gradients(y, [x])
if order > 0:
y1 = tape_dy
for _ in range(order - 1):
y1, = gradients_impl.gradients(y1, [x])
else:
y1 = y
return y, y1
for order, expected_f in enumerate(_COS_DERIVATIVES):
expected = self.evaluate(expected_f(constant_op.constant(1.)))
self.assertAllClose(
(expected, expected),
f(constant_op.constant(1.), order))
def testIteratedGradientsNested(self):
def _grad(f):
def _grad_function(primal):
with backprop.GradientTape() as tape:
tape.watch(primal)
primal_out = f(primal)
return tape.gradient(primal_out, primal)
return _grad_function
@def_function.function
def _forward(x):
return math_ops.cos(x)
f = _forward
traced_f = def_function.function(f)
one = constant_op.constant(1.)
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(one), f(one))
self.assertAllClose(expected(one), traced_f(one))
self.assertAllClose(expected(one), def_function.function(f)(one))
f = _grad(f)
traced_f = def_function.function(_grad(traced_f))
def testIteratedGradientsNestedWithVariable(self):
def _grad(f):
def _grad_function():
with backprop.GradientTape() as tape:
primal_out = f()
g, = tape.gradient(primal_out, tape.watched_variables())
return g
return _grad_function
v = variables.Variable(2.)
@def_function.function
def _forward():
return math_ops.cos(v)
f = _forward
two = constant_op.constant(2.)
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(two), f())
self.assertAllClose(expected(two), def_function.function(f)())
f = _grad(f)
def testIteratedGradientsPersistent(self):
@def_function.function
def _forward(z):
return math_ops.cos(z)
f = _forward
with backprop.GradientTape(persistent=True) as tape:
start = constant_op.constant(1.)
tape.watch(start)
x = f(start)
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(start), x)
x = tape.gradient(x, start)
def testHigherOrderWithVariable(self):
v = variables.Variable(1.)
@def_function.function
def _forward():
return math_ops.cos(v)
f = _forward
with backprop.GradientTape(persistent=True) as tape:
x = f()
for expected in _COS_DERIVATIVES:
self.assertAllClose(expected(constant_op.constant(1.)), x)
x, = tape.gradient(x, tape.watched_variables())
def testGradientsChained(self):
@def_function.function
def _forward(z):
return math_ops.cos(z)
f = _forward
x = constant_op.constant(1.)
with backprop.GradientTape() as t:
t.watch(x)
y = f(x)
with backprop.GradientTape() as tt:
doutputs = constant_op.constant(2.)
tt.watch(doutputs)
g = t.gradient(y, x, doutputs)
self.assertAllClose(-2. * math_ops.sin(x), g)
gg = tt.gradient(g, doutputs)
# We're taking gradients with respect to doutputs, which is just a linear
# function of the gradient.
self.assertAllClose(-math_ops.sin(x), gg)
def testSymGradGatherNd(self):
with ops.Graph().as_default(), self.cached_session() as sess:

View File

@ -148,6 +148,13 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
PyObject* TFE_Py_TapeSetIsEmpty();
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
// Like TFE_Py_TapeSetShouldRecord but with a ternary return:
// - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecord is false)
// - 1 if first-order gradients may be requested
// - 2 if higher-order gradients may be requested
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors);
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);

View File

@ -695,6 +695,14 @@ void SetOpAttrWithDefaults(
}
}
PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(num);
#else
return PyInt_FromLong(num);
#endif
}
// Python subclass of Exception that is created on not ok Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
@ -1500,33 +1508,51 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}
- PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors,
std::vector<tensorflow::int64>* tensor_ids,
std::vector<tensorflow::DataType>* dtypes) {
tensorflow::Safe_PyObjectPtr seq(
PySequence_Fast(tensors, "expected a sequence"));
if (seq == nullptr) {
return false;
}
int len = PySequence_Fast_GET_SIZE(seq.get());
tensor_ids->reserve(len);
dtypes->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
tensor_ids->push_back(FastTensorId(item));
dtypes->push_back(FastTensorDtype(item));
}
return true;
}
bool TapeCouldPossiblyRecord(PyObject* tensors) {
if (tensors == Py_None) {
- Py_RETURN_FALSE;
return false;
}
if (*ThreadTapeIsStopped()) {
- Py_RETURN_FALSE;
return false;
}
if (!HasTape()) {
return false;
}
return true;
}
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
Py_RETURN_FALSE;
}
- PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
- if (seq == nullptr) {
- return nullptr;
- }
- int len = PySequence_Fast_GET_SIZE(seq);
// TODO(apassos) consider not building a list and changing the API to check
// each tensor individually.
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
- tensor_ids.reserve(len);
- dtypes.reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- tensor_ids.push_back(FastTensorId(item));
- dtypes.push_back(FastTensorDtype(item));
- }
- Py_DECREF(seq);
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
@ -1543,6 +1569,53 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
return GetPythonObjectFromInt(0);
}
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
// If there is a persistent tape watching, or if there are multiple tapes
// watching, we'll return immediately indicating that higher-order tape
// gradients are possible.
bool some_tape_watching = false;
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
if (tape->tape->IsPersistent() || some_tape_watching) {
// Either this is the second tape watching, or this tape is persistent:
// higher-order gradients are possible.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
if (some_tape_watching) {
// This is the second tape watching: higher-order gradients are
// possible. Note that there's no equivalent of persistence for
// forward-mode.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
if (some_tape_watching) {
// There's exactly one non-persistent tape. The user can request first-order
// gradients but won't be able to get higher-order tape gradients.
return GetPythonObjectFromInt(1);
} else {
// There are no tapes. The user can't request tape gradients.
return GetPythonObjectFromInt(0);
}
}
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (*ThreadTapeIsStopped()) {
return;
@ -1997,14 +2070,6 @@ PyObject* GetPythonObjectFromString(const char* s) {
#endif
}
PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(num);
#else
return PyInt_FromLong(num);
#endif
}
bool CheckResourceVariable(PyObject* item) {
if (PyObject_TypeCheck(item, resource_variable_type)) {
tensorflow::Safe_PyObjectPtr handle(

View File

@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@ -33,3 +35,25 @@ def get_zeros_dtype(t):
else:
return handle_data.shape_and_type[0].dtype
return t.dtype
def shape_and_dtype(t):
"""Return the shape and dtype for the default gradient for a Tensor."""
if t.dtype == dtypes.resource:
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if (handle_data is None or not handle_data.is_set or
len(handle_data.shape_and_type) != 1):
return tensor_shape.TensorShape(None), dtypes.float32
else:
shape_and_type = handle_data.shape_and_type[0]
return (tensor_shape.TensorShape(shape_and_type.shape),
dtypes.as_dtype(shape_and_type.dtype))
return t.shape, t.dtype
def zeros_like(t):
"""Like array_ops.zeros_like, but respects resource handles."""
if t.dtype == dtypes.resource:
return array_ops.zeros(*shape_and_dtype(t))
else:
return array_ops.zeros_like(t)
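A rough usage sketch of the helpers added above (they back the gradient plumbing in function.py; tensorflow.python.ops.default_gradient is an internal module, so this is illustrative only):

# Sketch: default gradient shape/dtype and zeros for a resource tensor.
import tensorflow as tf
from tensorflow.python.ops import default_gradient

v = tf.Variable([1., 2., 3.])
shape, dtype = default_gradient.shape_and_dtype(v.handle)
print(shape, dtype)  # (3,) and float32, recovered from the handle data
print(default_gradient.zeros_like(v.handle).numpy())  # [0. 0. 0.]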

View File

@ -74,6 +74,7 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetIsStopped;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeGradient;

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
@ -174,6 +175,18 @@ class Loader(object):
for bound_input, internal_capture in zip(
bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
concrete_function.graph.captures[bound_input] = internal_capture
if internal_capture.dtype == dtypes.resource:
if resource_variable_ops.is_resource_variable(bound_input):
try:
handle = bound_input.handle
except ValueError:
# For mirrored variables we'll copy handle data for components
# as they get captured.
pass
else:
custom_gradient.copy_handle_data(handle, internal_capture)
else:
custom_gradient.copy_handle_data(bound_input, internal_capture)
# Setting "captures" first means "capture" won't create a new
# placeholder for this input.
concrete_function.graph.capture(bound_input)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import weakref
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
@ -172,14 +173,21 @@ class CapturableResourceDeleter(object):
def __init__(self, destroy_resource_fn=None):
if destroy_resource_fn:
- self.destroy_resource = destroy_resource_fn
self._destroy_resource = destroy_resource_fn
self._destruction_context = (
context.eager_mode if context.executing_eagerly()
else ops.get_default_graph().as_default)
else:
self._destroy_resource = None
def destroy_resource(self):
- """A function that destroys the resource."""
- pass
if self._destroy_resource:
return self._destroy_resource()
def __del__(self):
- self.destroy_resource()
if self._destroy_resource:
with self._destruction_context():
self._destroy_resource()
class CapturableResource(base.Trackable):