Accept output gradients of side outputs when calling functions
Fixes higher-order gradients of function calls.

When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the forward function's outputs. Building this pair sometimes takes a few iterations to converge (each backward pass can capture new forward-pass tensors, which become new side outputs), but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested.

When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request.

PiperOrigin-RevId: 256268751
parent aabcdcbdff
commit bd4feec252
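The user-visible effect is easiest to see outside the diff. The following is a rough usage sketch written against the public TensorFlow 2.x API (tf.function, tf.GradientTape, tf.gradients), modeled on the tests this commit adds below; the helper name iterated_symbolic_grad and the assertions are illustrative only and are not part of the change itself.

# Sketch of the two gradient paths this change targets, assuming TF 2.x.
import tensorflow as tf

# cos and its successive derivatives, used to check each gradient order.
_COS_DERIVATIVES = [tf.cos,
                    lambda x: -tf.sin(x),
                    lambda x: -tf.cos(x),
                    tf.sin,
                    tf.cos]


@tf.function
def iterated_symbolic_grad(x, order):
  # Differentiating a traced function call repeatedly with tf.gradients is
  # what re-requests the backward function; each extra order may add side
  # outputs to the rewritten forward call op.
  y = tf.function(lambda: tf.cos(x))()
  for _ in range(order):
    y, = tf.gradients(y, [x])
  return y


x = tf.constant(1.)
for order, expected in enumerate(_COS_DERIVATIVES):
  tf.debugging.assert_near(expected(x), iterated_symbolic_grad(x, order))

# A persistent tape (or nested tapes) means higher-order tape gradients are
# possible, so the traced call prepares a backward function that also accepts
# output gradients for its side outputs.
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y = tf.function(tf.cos)(x)
  for expected in _COS_DERIVATIVES[1:]:
    y = tape.gradient(y, x)  # computed inside the tape so it is recorded too
    tf.debugging.assert_near(expected(x), y)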
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import enum  # pylint: disable=g-bad-import-order
 import functools
 import itertools
 import threading
@@ -40,13 +41,16 @@ from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import error_interpolation
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
+from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops import resource_variable_ops
@@ -390,7 +394,8 @@ class _EagerDefinedFunction(object):
     self._output_types = [o.type for o in self.signature.output_arg]
     self._output_shapes = [o.shape for o in outputs]
     self._control_captures = graph.control_captures
-    self._func_graph_outputs = outputs
+    # Shallow copy outputs since ConcreteFunction may mutate it.
+    self._func_graph_outputs = list(outputs)
     self.grad_func_name = None
     self.python_grad_func = None
     self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -481,6 +486,13 @@ class _EagerDefinedFunction(object):
     return outputs
 
 
+class _PossibleTapeGradientTypes(enum.Enum):
+  """Represents the output of TFE_Py_TapeSetPossibleGradientTypes."""
+  NONE = 0
+  FIRST_ORDER = 1
+  HIGHER_ORDER = 2
+
+
 class ConcreteFunction(object):
   """Callable object encapsulating a function definition and its gradient.
 
@@ -517,7 +529,27 @@ class ConcreteFunction(object):
     self._inference_function = _EagerDefinedFunction(
         _inference_name(self._func_graph.name), self._func_graph,
         self._func_graph.inputs, self._func_graph.outputs, self._attrs)
-    self._backward_graph_function = None
+
+    # When graph building without a tape active, symbolic gradients rely on
+    # regenerating the backward function for higher-order gradients (to account
+    # for new side outputs of the rewritten forward function call). Thus there
+    # is no fixed backward function for this case. However, when a tape is
+    # active (eager or graph building), we generate fixed backward and forward
+    # functions at forward function call time.
+    #
+    # This difference between the tape and non-tape cases is to avoid building
+    # unneeded backward functions while graph building (where we may or may not
+    # eventually need gradients).
+    self._tape_forward_function_first_order = None
+    self._tape_backward_function_first_order = None
+    self._tape_forward_function_higher_order = None
+    self._tape_backward_function_higher_order = None
+
+    # A map from the number of forward function outputs with accepted gradients
+    # to backward functions, used to cache non-tape backward function
+    # generation.
+    self._cached_graph_backprop_functions = {}
+
     self._signature = signature
     self._gradient_name = None
 
@@ -673,12 +705,28 @@ class ConcreteFunction(object):
             "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
 
-    if (tape.should_record(tensor_inputs) or
-        tape.should_record(captured_inputs)):
-      if context.executing_eagerly():
-        return self._eager_backprop_call(args)
-      else:
-        return self._backprop_call_with_delayed_rewrite(args)
+    possible_gradient_type = _PossibleTapeGradientTypes(
+        pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
+    if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
+      # There is a single non-persistent tape active, so the user can only
+      # request first-order gradients from a tape. We can spend less time graph
+      # building since we know this.
+      #
+      # We may still end up computing higher-order gradients, but that'd be
+      # through `tf.gradients`, which can re-write the forward pass and so needs
+      # no preparation here.
+      forward_function, backward_function = (
+          self._tape_functions_for_first_order())
+      return self._tape_backprop_call(args, forward_function, backward_function)
+    elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
+      # Either there's a persistent tape watching, or there are multiple nested
+      # tapes. Either way, the user may request higher-order gradients. We'll
+      # spend a bit more time and make sure higher-order gradients are correct.
+      forward_function, backward_function = (
+          self._tape_functions_for_higher_order())
+      return self._tape_backprop_call(args, forward_function, backward_function)
+    # else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
+    # tape is recording.
 
     # Only need to override the gradient in graph mode and when we have outputs.
     if context.executing_eagerly() or not self.outputs:
@@ -708,30 +756,39 @@ class ConcreteFunction(object):
 
   def _grad_fn(self, op, *doutputs):
     """Gradients of this function."""
-    if self._backward_graph_function is None:
-      self._construct_backprop_function()
+    backwards_function = self._graph_backprop_function(len(doutputs))
+    self._forward_function.add_to_graph(op.graph)
 
     # pylint: disable=protected-access
-    self._forward_function.add_to_graph(op.graph)
-    num_inference_outputs = self._inference_function._num_outputs
     # Rewrite an inference call op to be a forward call op
-    if op.get_attr("f").name.encode() == self._inference_function.name:
-      op._set_func_attr("f", self._forward_function.name)
-      op._set_type_list_attr("Tout", self._forward_function._output_types)
-      op._add_outputs(
-          self._forward_function._output_types[num_inference_outputs:],
-          self._forward_function._output_shapes[num_inference_outputs:])
-      for i in range(num_inference_outputs, len(op.outputs)):
-        func_graph_output = self._forward_function._func_graph_outputs[i]
-        custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
+    op._set_func_attr("f", self._forward_function.name)
+    op._set_type_list_attr("Tout", self._forward_function._output_types)
+    op._add_outputs(
+        self._forward_function._output_types[len(op.outputs):],
+        self._forward_function._output_shapes[len(op.outputs):])
+    for i in range(len(op.outputs)):
+      func_graph_output = self._forward_function._func_graph_outputs[i]
+      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
     # pylint: enable=protected-access
 
+    capture_mapping = dict(zip(self._func_graph.outputs, op.outputs))
+    remapped_captures = []
+    for capture in backwards_function.captured_inputs:
+      remapped_captures.append(capture_mapping.get(capture, capture))
+
+    # Replace Nones with zeros since we're calling a graph function which
+    # expects numeric inputs.
+    cleaned_doutputs = []
+    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
+      if gradients_util.IsTrainable(placeholder):
+        if doutput is not None:
+          cleaned_doutputs.append(doutput)
+        else:
+          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
+
     # Compute the gradients using the side outputs
-    side_outputs = op.outputs[num_inference_outputs:]
-    args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
-    return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
-        (a for a in args if a is not None),
-        self._backward_graph_function.captured_inputs)
+    return backwards_function._call_flat(  # pylint: disable=protected-access
+        cleaned_doutputs, remapped_captures)
 
   @property
   def name(self):
@@ -820,16 +877,190 @@ class ConcreteFunction(object):
     # 2. Otherwise, defun will create two functions, one for forward pass,
     # and the backward pass will be created via tape.
     # When registering the function, we register both cases.
-    if self._backward_graph_function is None:
-      self._construct_backprop_function()
+    backward_function = self._graph_backprop_function()._inference_function
     forward_function = self._forward_function
-    backward_function = self._backward_graph_function._inference_function
     # pylint: enable=protected-access
     forward_function.add_to_graph(g)
     backward_function.add_to_graph(g)
 
-  def _construct_backprop_function(self):
-    """Constructs the backprop function object for this function."""
+  def _graph_backprop_function(self, num_doutputs=None):
+    """A possibly-cached backprop function."""
+    backward_function = self._cached_graph_backprop_functions.get(
+        num_doutputs, None)
+    if backward_function is not None:
+      return backward_function
+    backward_function = self._construct_graph_backprop_function(num_doutputs)
+    self._cached_graph_backprop_functions[num_doutputs] = backward_function
+    return backward_function
+
+  def _construct_graph_backprop_function(self, num_doutputs=None):
+    """Constructs a backprop function object for this function.
+
+    Args:
+      num_doutputs: The constructed backprop function will take output gradients
+        for the first `num_doutputs` outputs of the forward function. Defaults
+        to the number of outputs for the inference function, but when
+        higher-order gradients are computed this will increase to include side
+        outputs.
+
+    Returns:
+      A backward function taking `num_doutputs` arguments and returning
+      gradients with respect to inputs of the forward function.
+
+      self._forward_function is re-generated to account for new side outputs, if
+      any extra were required when building the backward pass.
+    """
+    if num_doutputs is None:
+      num_doutputs = len(self._inference_function.signature.output_arg)
+    trainable_outputs = [
+        output for output in self._func_graph.outputs[:num_doutputs]
+        if gradients_util.IsTrainable(output)]
+
+    signature = []
+    for t in trainable_outputs:
+      signature.append(
+          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
+
+    def _backprop_function(*grad_ys):
+      return gradients_util._GradientsHelper(  # pylint: disable=protected-access
+          trainable_outputs,
+          self._func_graph.inputs,
+          grad_ys=grad_ys,
+          src_graph=self._func_graph)
+
+    with self._func_graph.as_default():
+      backwards_graph = func_graph_module.FuncGraph(
+          _backward_name(self._func_graph.name))
+      func_graph_module.func_graph_from_py_func(
+          name=backwards_graph.name,
+          python_func=_backprop_function,
+          args=[], kwargs={},
+          signature=signature,
+          func_graph=backwards_graph)
+      backwards_graph_captures = list(backwards_graph.captures.keys())
+      captures_from_forward = [
+          c for c in backwards_graph_captures if
+          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+
+      forward_function_name = _forward_name(self._func_graph.name)
+
+      existing_outputs = set(self._func_graph.outputs)
+      for capture in captures_from_forward:
+        if capture not in existing_outputs:
+          existing_outputs.add(capture)
+          self._func_graph.outputs.append(capture)
+      backward_function_attr = _parse_func_attrs(
+          {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+      backward_function_attr.update(self._attrs)
+
+      backward_function = ConcreteFunction(
+          backwards_graph, attrs=backward_function_attr)
+      forward_function_attr = _parse_func_attrs({
+          BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+              backward_function._inference_function.name})  # pylint: disable=protected-access
+      forward_function_attr.update(self._attrs)
+
+      self._forward_function = _EagerDefinedFunction(
+          forward_function_name, self._func_graph, self._func_graph.inputs,
+          self._func_graph.outputs, forward_function_attr)
+    return backward_function
+
+  def _tape_functions_for_first_order(self):
+    """Shortcut for when only first-order gradients are required.
+
+    The returned backward function does not accept gradients with respect to
+    side output of forward_function. This is fine as long as the user can't
+    possibly request second order tape gradients, as when they've used a single
+    non-persistent GradientTape. Since we don't need the backward function to
+    take gradients with respect to side outputs, we can skip some potentially
+    slow graph building.
+
+    Returns:
+      A tuple of (forward_function, backward_function):
+        forward_function: Takes the same inputs as the inference function, but
+          returns side outputs used by backward_function in addition to the
+          inference function's outputs.
+        backward_function: Takes side outputs from forward_function and
+          gradients with respect to the "real" outputs of forward_function and
+          returns gradients with respect to the inputs.
+    """
+    if self._tape_forward_function_first_order is not None:
+      return (self._tape_forward_function_first_order,
+              self._tape_backward_function_first_order)
+    outputs = self._func_graph.outputs[
+        :len(self._inference_function.signature.output_arg)]
+    forward_function, backward_function = (
+        self._tape_forward_and_backward_functions(outputs))
+    self._tape_forward_function_first_order = forward_function
+    self._tape_backward_function_first_order = backward_function
+    return forward_function, backward_function
+
+  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
+  # generalizing if so.
+  def _tape_functions_for_higher_order(self):
+    """Forward and backward functions suitable for higher-order gradients.
+
+    Unlike `_tape_functions_for_first_order`, the backward function built by
+    this method accepts gradients for all of the outputs of the returned forward
+    function, including side outputs.
+
+    Returns:
+      A tuple of (forward_function, backward_function):
+        forward_function: Takes the same inputs as the inference function, but
+          returns side outputs used by backward_function in addition to the
+          inference function's outputs.
+        backward_function: Takes side outputs from forward_function and
+          gradients with respect to all of its outputs, real and side. Returns
+          gradients with respect to the inputs.
+    """
+    if self._tape_forward_function_higher_order is not None:
+      return (self._tape_forward_function_higher_order,
+              self._tape_backward_function_higher_order)
+    outputs = []
+    # First we need to figure out how many side outputs from the forward pass
+    # will be required. We do this in a temporary graph to avoid actually
+    # running multiple copies of the backward pass (one per _GradientsHelper
+    # call).
+    #
+    # While computing gradients, the backward function captures Tensors from
+    # the forward function. We add these as side outputs of the original
+    # function. However, we then need to accept output gradients with respect
+    # to these side outputs for higher order gradients to work. Thus we loop
+    # until the number of outputs of the function stabilizes. Note that this
+    # is only required for tape gradients, where we need to declare in advance
+    # all of the forward op's outputs: symbolic gradients with tf.gradients
+    # instead rely on regenerating backward functions when higher-order
+    # gradients are requested.
+    while len(outputs) < len(self._func_graph.outputs):
+      new_outputs = self._func_graph.outputs[len(outputs):]
+      outputs = list(self._func_graph.outputs)
+      self._tape_forward_and_backward_functions(new_outputs)
+    forward_function, backward_function = (
+        self._tape_forward_and_backward_functions(outputs))
+    if len(self._func_graph.outputs) != len(outputs):
+      raise AssertionError(
+          ("Unexpectedly added new outputs to the forward function when "
+           "building the backward function: {}").format(
+               self._func_graph.outputs[len(outputs):]))
+    self._tape_forward_function_higher_order = forward_function
+    self._tape_backward_function_higher_order = backward_function
+    return forward_function, backward_function
+
+  def _tape_forward_and_backward_functions(self, outputs):
+    """Constructs tape forward and back functions for `outputs`."""
+    # First figure out which of `outputs` are trainable. We'll accept gradients
+    # for each of these in the backward function.
+    handles_to_variables = {self._func_graph.captures[v.handle]: v
+                            for v in self._func_graph.variables
+                            if v.handle in self._func_graph.captures}
+    trainable_outputs = []
+    for output in outputs:
+      if gradients_util.IsTrainable(output):
+        # Swap in the Variable object for resource handles if we can so
+        # sparse gradients work.
+        output = handles_to_variables.get(output, output)
+        trainable_outputs.append(output)
+
     backwards_graph = func_graph_module.FuncGraph(
         _backward_name(self._func_graph.name))
     # Keep track of the forward graph so that if the backwards graph
@@ -837,73 +1068,79 @@ class ConcreteFunction(object):
     # the forward graph. This is an edge case that can only happen with
     # tf.custom_gradient.
     backwards_graph._forward_func_graph = self._func_graph  # pylint: disable=protected-access
-    forward_function_name = _forward_name(self._func_graph.name)
-    outputs = [x for x in self._func_graph.outputs
-               if gradients_util.IsTrainable(x)]
     with backwards_graph.as_default():
-      gradients_wrt_outputs = [
-          graph_placeholder(x.dtype, x.shape) for x in outputs
-      ]
+      gradients_wrt_outputs = []
+      for output in trainable_outputs:
+        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
+            output)
+        gradients_wrt_outputs.append(
+            graph_placeholder(gradient_dtype, gradient_shape))
       gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
-          outputs,
+          trainable_outputs,
           self._func_graph.inputs,
           grad_ys=gradients_wrt_outputs,
           src_graph=self._func_graph)
 
-    backwards_graph_captures = list(backwards_graph.captures.keys())
+    captures_from_forward = [
+        c for c in backwards_graph.captures.keys() if
+        not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+    existing_outputs = set(self._func_graph.outputs)
+    for capture in captures_from_forward:
+      if capture not in existing_outputs:
+        existing_outputs.add(capture)
+        self._func_graph.outputs.append(capture)
+
+    forward_function_name = _forward_name(self._func_graph.name)
     backward_function_attr = _parse_func_attrs(
         {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
     backward_function_attr.update(self._attrs)
 
     # The ordering of `backwards_graph.inputs` is important: inputs of
-    # `self._backward_graph_function` correspond to outputs of
-    # `self._forward_function`.
-    backwards_graph.inputs = gradients_wrt_outputs + list(
-        backwards_graph.captures.values())
-    # Clear captures, since we pass them in as inputs.
-    backwards_graph.captures = {}
+    # `backward_function` correspond to outputs (including
+    # side outputs) of `self._tape_forward_function`.
+    backwards_graph.inputs = (
+        gradients_wrt_outputs + list(backwards_graph.captures.values()))
     backwards_graph.outputs.extend(
         grad
         for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
         if grad is not None)
     backwards_graph.structured_outputs = gradients_wrt_inputs
-    self._backward_graph_function = ConcreteFunction(
+    backward_function = ConcreteFunction(
         backwards_graph, attrs=backward_function_attr)
 
     forward_function_attr = _parse_func_attrs({
         BACKWARD_FUNCTION_ATTRIBUTE_NAME:
-            self._backward_graph_function._inference_function.name})  # pylint: disable=protected-access
+            backward_function._inference_function.name})  # pylint: disable=protected-access
     forward_function_attr.update(self._attrs)
-    self._forward_function = _EagerDefinedFunction(
+    forward_function = _EagerDefinedFunction(
         forward_function_name, self._func_graph, self._func_graph.inputs,
-        self._func_graph.outputs + backwards_graph_captures,
+        self._func_graph.outputs,
         forward_function_attr)
+    return forward_function, backward_function
 
-  def _eager_backprop_call(self, args):
+  def _tape_backprop_call(self, args, forward_function, backward_function):
     """Calls the forward function and records the result on a tape.
 
-    This method fully constructs the forward and backward functions before
-    calling the function and recording them on the tape.
-
-    (Only records results on a tape if the function has outputs).
-
     Args:
       args: All inputs to the function, including resolved captured inputs
+      forward_function: The forward pass, outputting both user-specified and
+        side outputs.
+      backward_function: Computes gradients for inputs of forward_function given
+        output gradients for the first `N` of forward_function's outputs, not
+        necessarily all of them. See `_tape_functions_for_first_order` and
+        `_tape_functions_for_higher_order`.
 
     Returns:
      The call output.
    """
-    if self._backward_graph_function is None:
-      self._construct_backprop_function()
-
     ctx = context.context()
 
     self._register_gradient()
     with ops.get_default_graph().gradient_override_map(
         {"PartitionedCall": self._gradient_name,
          "StatefulPartitionedCall": self._gradient_name}):
-      outputs = self._forward_function.call(ctx, args)
+      outputs = forward_function.call(ctx, args)
 
     if isinstance(outputs, ops.Operation) or outputs is None:
       return outputs
@@ -912,19 +1149,55 @@ class ConcreteFunction(object):
     # `side_outputs` are the intermediate Tensors that were added as outputs to
     # the forward graph function so that we can compute its gradient.
     real_outputs = outputs[:self._num_outputs]
-    skip_positions = [i for i, t in enumerate(real_outputs)
-                      if not gradients_util.IsTrainable(t)]
-    side_outputs = outputs[self._num_outputs:]
+    capture_mapping = dict(zip(self._func_graph.outputs, outputs))
+    remapped_captures = [
+        capture_mapping.get(capture, capture)
+        for capture in backward_function.captured_inputs]
+    # We may need to use zeros_like to get a zero for variant Tensors with
+    # unconnected gradients. We do that in advance so we don't have to hold on
+    # to the outputs themselves, which may not be needed otherwise.
+    variant_zeros_like = {}
+    backward_function_inputs = (
+        len(backward_function.inputs) - len(backward_function.captured_inputs))
+    recorded_outputs = []
+    trainable_recorded_outputs = 0
+    skip_positions = []
+    for output_index, output in enumerate(outputs):
+      if trainable_recorded_outputs < backward_function_inputs:
+        recorded_outputs.append(output)
+      if gradients_util.IsTrainable(output):
+        trainable_recorded_outputs += 1
+      else:
+        skip_positions.append(output_index)
+      if output.dtype == dtypes.variant:
+        variant_zeros_like[output_index] = default_gradient.zeros_like(output)
 
-    def backward_function(*args):
-      args = [a for i, a in enumerate(args)
-              if a is not None and i not in skip_positions]
-      return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
-          list(args) + side_outputs,
-          self._backward_graph_function.captured_inputs)
+    def _backward_function_wrapper(*args):
+      """Process output gradients and call the backward function."""
+      processed_args = []
+      input_index = 0
+      for output_index, arg in enumerate(args):
+        if output_index in skip_positions:
+          continue
+        if arg is None:
+          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
+          # have a Tensor value for each Tensor we thought would be trainable
+          # based on its dtype, even if it ended up being unconnected.
+          input_placeholder = backward_function.inputs[
+              input_index]
+          if input_placeholder.dtype == dtypes.variant:
+            arg = variant_zeros_like[output_index]
+          else:
+            arg = array_ops.zeros(
+                *default_gradient.shape_and_dtype(input_placeholder))
+        processed_args.append(arg)
+        input_index += 1
+      return backward_function._call_flat(  # pylint: disable=protected-access
+          processed_args, remapped_captures)
 
-    tape.record_operation(self._forward_function.signature.name, real_outputs,
-                          args, backward_function)
+    tape.record_operation(forward_function.signature.name,
+                          recorded_outputs, args, _backward_function_wrapper)
     return self._build_call_outputs(real_outputs)
 
   def _backprop_call_with_delayed_rewrite(self, args):
@@ -40,6 +40,13 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import nest
 
 
+_COS_DERIVATIVES = [math_ops.cos,
+                    lambda x: -math_ops.sin(x),
+                    lambda x: -math_ops.cos(x),
+                    math_ops.sin,
+                    math_ops.cos]
+
+
 class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
 
   def testGraphModeWithGradients(self):
@@ -68,6 +75,145 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
       self.assertAllEqual(grads.eval(), 2.0)
       self.assertEqual(grads.shape, v.shape)
 
+  def testSymbolicHigherOrder(self):
+    @def_function.function
+    def f(x, order):
+      y = def_function.function(lambda: math_ops.cos(x))()
+      for _ in range(order):
+        y, = gradients_impl.gradients(y, [x])
+      return y
+    for order, expected in enumerate(_COS_DERIVATIVES):
+      self.assertAllClose(
+          expected(constant_op.constant(1.)),
+          f(constant_op.constant(1.), order))
+
+  @parameterized.parameters([dict(persistent=True),
+                             dict(persistent=False)])
+  def testSymbolicHigherOrderUnderTape(self, persistent):
+    @def_function.function
+    def f(x, order):
+      with backprop.GradientTape(persistent=persistent) as tape:
+        tape.watch(x)
+        # Note that having a tape active, even if we don't use it, forces us
+        # down a different function call path. Symbolic gradients should work
+        # here too; correctness of tape gradients are tested elsewhere.
+        y = def_function.function(lambda: math_ops.cos(x))()
+      tape_dy = tape.gradient(y, x)
+      for _ in range(order):
+        y, = gradients_impl.gradients(y, [x])
+      if order > 0:
+        y1 = tape_dy
+        for _ in range(order - 1):
+          y1, = gradients_impl.gradients(y1, [x])
+      else:
+        y1 = y
+      return y, y1
+    for order, expected_f in enumerate(_COS_DERIVATIVES):
+      expected = self.evaluate(expected_f(constant_op.constant(1.)))
+      self.assertAllClose(
+          (expected, expected),
+          f(constant_op.constant(1.), order))
+
+  def testIteratedGradientsNested(self):
+
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.gradient(primal_out, primal)
+      return _grad_function
+
+    @def_function.function
+    def _forward(x):
+      return math_ops.cos(x)
+
+    f = _forward
+    traced_f = def_function.function(f)
+    one = constant_op.constant(1.)
+    for expected in _COS_DERIVATIVES:
+      self.assertAllClose(expected(one), f(one))
+      self.assertAllClose(expected(one), traced_f(one))
+      self.assertAllClose(expected(one), def_function.function(f)(one))
+      f = _grad(f)
+      traced_f = def_function.function(_grad(traced_f))
+
+  def testIteratedGradientsNestedWithVariable(self):
+
+    def _grad(f):
+      def _grad_function():
+        with backprop.GradientTape() as tape:
+          primal_out = f()
+        g, = tape.gradient(primal_out, tape.watched_variables())
+        return g
+      return _grad_function
+
+    v = variables.Variable(2.)
+
+    @def_function.function
+    def _forward():
+      return math_ops.cos(v)
+
+    f = _forward
+
+    two = constant_op.constant(2.)
+
+    for expected in _COS_DERIVATIVES:
+      self.assertAllClose(expected(two), f())
+      self.assertAllClose(expected(two), def_function.function(f)())
+      f = _grad(f)
+
+  def testIteratedGradientsPersistent(self):
+
+    @def_function.function
+    def _forward(z):
+      return math_ops.cos(z)
+
+    f = _forward
+    with backprop.GradientTape(persistent=True) as tape:
+      start = constant_op.constant(1.)
+      tape.watch(start)
+      x = f(start)
+      for expected in _COS_DERIVATIVES:
+        self.assertAllClose(expected(start), x)
+        x = tape.gradient(x, start)
+
+  def testHigherOrderWithVariable(self):
+
+    v = variables.Variable(1.)
+
+    @def_function.function
+    def _forward():
+      return math_ops.cos(v)
+
+    f = _forward
+    with backprop.GradientTape(persistent=True) as tape:
+      x = f()
+      for expected in _COS_DERIVATIVES:
+        self.assertAllClose(expected(constant_op.constant(1.)), x)
+        x, = tape.gradient(x, tape.watched_variables())
+
+  def testGradientsChained(self):
+
+    @def_function.function
+    def _forward(z):
+      return math_ops.cos(z)
+
+    f = _forward
+    x = constant_op.constant(1.)
+    with backprop.GradientTape() as t:
+      t.watch(x)
+      y = f(x)
+    with backprop.GradientTape() as tt:
+      doutputs = constant_op.constant(2.)
+      tt.watch(doutputs)
+      g = t.gradient(y, x, doutputs)
+    self.assertAllClose(-2. * math_ops.sin(x), g)
+    gg = tt.gradient(g, doutputs)
+    # We're taking gradients with respect to doutputs, which is just a linear
+    # function of the gradient.
+    self.assertAllClose(-math_ops.sin(x), gg)
+
   def testSymGradGatherNd(self):
     with ops.Graph().as_default(), self.cached_session() as sess:
 
@@ -148,6 +148,13 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
 PyObject* TFE_Py_TapeSetIsEmpty();
 
 PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
+
+// Like TFE_Py_TapeSetShouldRecord but with a ternary return:
+// - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecord is false)
+// - 1 if first-order gradients may be requested
+// - 2 if higher-order gradients may be requested
+PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors);
+
 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
 
@@ -695,6 +695,14 @@ void SetOpAttrWithDefaults(
   }
 }
 
+PyObject* GetPythonObjectFromInt(int num) {
+#if PY_MAJOR_VERSION >= 3
+  return PyLong_FromLong(num);
+#else
+  return PyInt_FromLong(num);
+#endif
+}
+
 // Python subclass of Exception that is created on not ok Status.
 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
 PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
@@ -1500,33 +1508,51 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
   return tensor_ids;
 }
 
-PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
+// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
+// null. Returns true on success and false on a Python exception.
+bool TensorShapesAndDtypes(PyObject* tensors,
+                           std::vector<tensorflow::int64>* tensor_ids,
+                           std::vector<tensorflow::DataType>* dtypes) {
+  tensorflow::Safe_PyObjectPtr seq(
+      PySequence_Fast(tensors, "expected a sequence"));
+  if (seq == nullptr) {
+    return false;
+  }
+  int len = PySequence_Fast_GET_SIZE(seq.get());
+  tensor_ids->reserve(len);
+  dtypes->reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+    tensor_ids->push_back(FastTensorId(item));
+    dtypes->push_back(FastTensorDtype(item));
+  }
+  return true;
+}
+
+bool TapeCouldPossiblyRecord(PyObject* tensors) {
   if (tensors == Py_None) {
-    Py_RETURN_FALSE;
+    return false;
   }
   if (*ThreadTapeIsStopped()) {
-    Py_RETURN_FALSE;
+    return false;
   }
   if (!HasTape()) {
+    return false;
+  }
+  return true;
+}
+
+PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
+  if (!TapeCouldPossiblyRecord(tensors)) {
     Py_RETURN_FALSE;
   }
-  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
-  if (seq == nullptr) {
-    return nullptr;
-  }
-  int len = PySequence_Fast_GET_SIZE(seq);
   // TODO(apassos) consider not building a list and changing the API to check
   // each tensor individually.
   std::vector<tensorflow::int64> tensor_ids;
   std::vector<tensorflow::DataType> dtypes;
-  tensor_ids.reserve(len);
-  dtypes.reserve(len);
-  for (int i = 0; i < len; ++i) {
-    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
-    tensor_ids.push_back(FastTensorId(item));
-    dtypes.push_back(FastTensorDtype(item));
-  }
-  Py_DECREF(seq);
+  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
+    return nullptr;
+  }
   auto tape_set = *GetTapeSet();
   for (TFE_Py_Tape* tape : tape_set) {
     if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
@@ -1543,6 +1569,53 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   Py_RETURN_FALSE;
 }
 
+PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
+  if (!TapeCouldPossiblyRecord(tensors)) {
+    return GetPythonObjectFromInt(0);
+  }
+  std::vector<tensorflow::int64> tensor_ids;
+  std::vector<tensorflow::DataType> dtypes;
+  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
+    return nullptr;
+  }
+
+  // If there is a persistent tape watching, or if there are multiple tapes
+  // watching, we'll return immediately indicating that higher-order tape
+  // gradients are possible.
+  bool some_tape_watching = false;
+  auto tape_set = *GetTapeSet();
+  for (TFE_Py_Tape* tape : tape_set) {
+    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
+      if (tape->tape->IsPersistent() || some_tape_watching) {
+        // Either this is the second tape watching, or this tape is persistent:
+        // higher-order gradients are possible.
+        return GetPythonObjectFromInt(2);
+      }
+      some_tape_watching = true;
+    }
+  }
+  auto forward_accumulators = *GetAccumulatorSet();
+  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
+    if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
+      if (some_tape_watching) {
+        // This is the second tape watching: higher-order gradients are
+        // possible. Note that there's no equivalent of persistence for
+        // forward-mode.
+        return GetPythonObjectFromInt(2);
+      }
+      some_tape_watching = true;
+    }
+  }
+  if (some_tape_watching) {
+    // There's exactly one non-persistent tape. The user can request first-order
+    // gradients but won't be able to get higher-order tape gradients.
+    return GetPythonObjectFromInt(1);
+  } else {
+    // There are no tapes. The user can't request tape gradients.
+    return GetPythonObjectFromInt(0);
+  }
+}
+
 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
   if (*ThreadTapeIsStopped()) {
     return;
@@ -1997,14 +2070,6 @@ PyObject* GetPythonObjectFromString(const char* s) {
 #endif
 }
 
-PyObject* GetPythonObjectFromInt(int num) {
-#if PY_MAJOR_VERSION >= 3
-  return PyLong_FromLong(num);
-#else
-  return PyInt_FromLong(num);
-#endif
-}
-
 bool CheckResourceVariable(PyObject* item) {
   if (PyObject_TypeCheck(item, resource_variable_type)) {
     tensorflow::Safe_PyObjectPtr handle(
@@ -18,6 +18,8 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 
 
@@ -33,3 +35,25 @@ def get_zeros_dtype(t):
     else:
       return handle_data.shape_and_type[0].dtype
   return t.dtype
+
+
+def shape_and_dtype(t):
+  """Return the shape and dtype for the default gradient for a Tensor."""
+  if t.dtype == dtypes.resource:
+    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
+    if (handle_data is None or not handle_data.is_set or
+        len(handle_data.shape_and_type) != 1):
+      return tensor_shape.TensorShape(None), dtypes.float32
+    else:
+      shape_and_type = handle_data.shape_and_type[0]
+      return (tensor_shape.TensorShape(shape_and_type.shape),
+              dtypes.as_dtype(shape_and_type.dtype))
+  return t.shape, t.dtype
+
+
+def zeros_like(t):
+  """Like array_ops.zeros_like, but respects resource handles."""
+  if t.dtype == dtypes.resource:
+    return array_ops.zeros(*shape_and_dtype(t))
+  else:
+    return array_ops.zeros_like(t)
@@ -74,6 +74,7 @@ limitations under the License.
 %rename("%s") TFE_Py_TapeSetIsStopped;
 %rename("%s") TFE_Py_TapeSetIsEmpty;
 %rename("%s") TFE_Py_TapeSetShouldRecord;
+%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
 %rename("%s") TFE_Py_TapeSetDeleteTrace;
 %rename("%s") TFE_Py_TapeSetRecordOperation;
 %rename("%s") TFE_Py_TapeGradient;
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.saved_model import function_deserialization
@@ -174,6 +175,18 @@ class Loader(object):
       for bound_input, internal_capture in zip(
           bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
         concrete_function.graph.captures[bound_input] = internal_capture
+        if internal_capture.dtype == dtypes.resource:
+          if resource_variable_ops.is_resource_variable(bound_input):
+            try:
+              handle = bound_input.handle
+            except ValueError:
+              # For mirrored variables we'll copy handle data for components
+              # as they get captured.
+              pass
+            else:
+              custom_gradient.copy_handle_data(handle, internal_capture)
+          else:
+            custom_gradient.copy_handle_data(bound_input, internal_capture)
         # Setting "captures" first means "capture" won't create a new
         # placeholder for this input.
         concrete_function.graph.capture(bound_input)
@@ -20,6 +20,7 @@ from __future__ import print_function
 import functools
 import weakref
 
+from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import dtypes
@@ -172,14 +173,21 @@ class CapturableResourceDeleter(object):
 
   def __init__(self, destroy_resource_fn=None):
     if destroy_resource_fn:
-      self.destroy_resource = destroy_resource_fn
+      self._destroy_resource = destroy_resource_fn
+      self._destruction_context = (
+          context.eager_mode if context.executing_eagerly()
+          else ops.get_default_graph().as_default)
+    else:
+      self._destroy_resource = None
 
   def destroy_resource(self):
-    """A function that destroys the resource."""
-    pass
+    if self._destroy_resource:
+      return self._destroy_resource()
 
   def __del__(self):
-    self.destroy_resource()
+    if self._destroy_resource:
+      with self._destruction_context():
+        self._destroy_resource()
 
 
 class CapturableResource(base.Trackable):