From bd4feec2524aae31df5c7fce1f572f8d0fd8ca43 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 2 Jul 2019 17:02:39 -0700 Subject: [PATCH] Accept output gradients of side outputs when calling functions Fixes higher-order gradients of function calls When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the outputs of the forward function. This sometimes needs a few iterations to converge, but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested. When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request. PiperOrigin-RevId: 256268751 --- tensorflow/python/eager/function.py | 415 +++++++++++++++--- .../python/eager/function_gradients_test.py | 146 ++++++ tensorflow/python/eager/pywrap_tfe.h | 7 + tensorflow/python/eager/pywrap_tfe_src.cc | 111 ++++- tensorflow/python/ops/default_gradient.py | 24 + tensorflow/python/pywrap_tfe.i | 1 + tensorflow/python/saved_model/load.py | 13 + .../python/training/tracking/tracking.py | 16 +- 8 files changed, 635 insertions(+), 98 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 6a7987655da..f8fd53ec83d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import enum # pylint: disable=g-bad-import-order import functools import itertools import threading @@ -40,13 +41,16 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient +from tensorflow.python.ops import default_gradient from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_util from tensorflow.python.ops import resource_variable_ops @@ -390,7 +394,8 @@ class _EagerDefinedFunction(object): self._output_types = [o.type for o in self.signature.output_arg] self._output_shapes = [o.shape for o in outputs] self._control_captures = graph.control_captures - self._func_graph_outputs = outputs + # Shallow copy outputs since ConcreteFunction may mutate it. 
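As a user-facing illustration of what this commit fixes, a second-order tape gradient through a traced function call should now resolve; this is a minimal sketch using the public API (assuming a build that includes this change), not part of the patch itself:

import tensorflow as tf

@tf.function
def f(x):
  return tf.cos(x)

x = tf.constant(1.)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = f(x)
  dy = inner.gradient(y, x)    # -sin(1.)
d2y = outer.gradient(dy, x)    # -cos(1.); the higher-order case addressed here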
+ self._func_graph_outputs = list(outputs) self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) @@ -481,6 +486,13 @@ class _EagerDefinedFunction(object): return outputs +class _PossibleTapeGradientTypes(enum.Enum): + """Represents the output of TFE_Py_TapeSetPossibleGradientTypes.""" + NONE = 0 + FIRST_ORDER = 1 + HIGHER_ORDER = 2 + + class ConcreteFunction(object): """Callable object encapsulating a function definition and its gradient. @@ -517,7 +529,27 @@ class ConcreteFunction(object): self._inference_function = _EagerDefinedFunction( _inference_name(self._func_graph.name), self._func_graph, self._func_graph.inputs, self._func_graph.outputs, self._attrs) - self._backward_graph_function = None + + # When graph building without a tape active, symbolic gradients rely on + # regenerating the backward function for higher-order gradients (to account + # for new side outputs of the rewritten forward function call). Thus there + # is no fixed backward function for this case. However, when a tape is + # active (eager or graph building), we generate fixed backward and forward + # functions at forward function call time. + # + # This difference between the tape and non-tape cases is to avoid building + # unneeded backward functions while graph building (where we may or may not + # eventually need gradients). + self._tape_forward_function_first_order = None + self._tape_backward_function_first_order = None + self._tape_forward_function_higher_order = None + self._tape_backward_function_higher_order = None + + # A map from the number of forward function outputs with accepted gradients + # to backward functions, used to cache non-tape backward function + # generation. + self._cached_graph_backprop_functions = {} + self._signature = signature self._gradient_name = None @@ -673,12 +705,28 @@ class ConcreteFunction(object): "Tensor." % (self._func_graph.name, i, str(arg))) args = tensor_inputs + captured_inputs - if (tape.should_record(tensor_inputs) or - tape.should_record(captured_inputs)): - if context.executing_eagerly(): - return self._eager_backprop_call(args) - else: - return self._backprop_call_with_delayed_rewrite(args) + possible_gradient_type = _PossibleTapeGradientTypes( + pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args)) + if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER: + # There is a single non-persistent tape active, so the user can only + # request first-order gradients from a tape. We can spend less time graph + # building since we know this. + # + # We may still end up computing higher-order gradients, but that'd be + # through `tf.gradients`, which can re-write the forward pass and so needs + # no preparation here. + forward_function, backward_function = ( + self._tape_functions_for_first_order()) + return self._tape_backprop_call(args, forward_function, backward_function) + elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER: + # Either there's a persistent tape watching, or there are multiple nested + # tapes. Either way, the user may request higher-order gradients. We'll + # spend a bit more time and make sure higher-order gradients are correct. + forward_function, backward_function = ( + self._tape_functions_for_higher_order()) + return self._tape_backprop_call(args, forward_function, backward_function) + # else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no + # tape is recording. # Only need to override the gradient in graph mode and when we have outputs. 
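The three-way dispatch above can be read off from user code; a hedged sketch of which branch each common pattern takes, using the public API (branch names refer to _PossibleTapeGradientTypes):

import tensorflow as tf

@tf.function
def f(x):
  return tf.cos(x)

x = tf.constant(1.)

# NONE: nothing is recording, so no tape-specific forward/backward pair is built.
f(x)

# FIRST_ORDER: exactly one non-persistent tape; only first-order tape gradients
# can ever be requested, so less graph building is needed.
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
dy = tape.gradient(y, x)

# HIGHER_ORDER: a persistent tape (or nested tapes) may request second and
# higher derivatives, so the forward function must expose its side outputs.
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y = f(x)
  dy = tape.gradient(y, x)
d2y = tape.gradient(dy, x)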
if context.executing_eagerly() or not self.outputs: @@ -708,30 +756,39 @@ class ConcreteFunction(object): def _grad_fn(self, op, *doutputs): """Gradients of this function.""" - if self._backward_graph_function is None: - self._construct_backprop_function() + backwards_function = self._graph_backprop_function(len(doutputs)) + self._forward_function.add_to_graph(op.graph) # pylint: disable=protected-access - self._forward_function.add_to_graph(op.graph) - num_inference_outputs = self._inference_function._num_outputs - # Rewrite an inference call op to be a forward call op - if op.get_attr("f").name.encode() == self._inference_function.name: - op._set_func_attr("f", self._forward_function.name) - op._set_type_list_attr("Tout", self._forward_function._output_types) - op._add_outputs( - self._forward_function._output_types[num_inference_outputs:], - self._forward_function._output_shapes[num_inference_outputs:]) - for i in range(num_inference_outputs, len(op.outputs)): - func_graph_output = self._forward_function._func_graph_outputs[i] - custom_gradient.copy_handle_data(func_graph_output, op.outputs[i]) + op._set_func_attr("f", self._forward_function.name) + op._set_type_list_attr("Tout", self._forward_function._output_types) + op._add_outputs( + self._forward_function._output_types[len(op.outputs):], + self._forward_function._output_shapes[len(op.outputs):]) + for i in range(len(op.outputs)): + func_graph_output = self._forward_function._func_graph_outputs[i] + custom_gradient.copy_handle_data(func_graph_output, op.outputs[i]) # pylint: enable=protected-access + + capture_mapping = dict(zip(self._func_graph.outputs, op.outputs)) + remapped_captures = [] + for capture in backwards_function.captured_inputs: + remapped_captures.append(capture_mapping.get(capture, capture)) + + # Replace Nones with zeros since we're calling a graph function which + # expects numeric inputs. + cleaned_doutputs = [] + for doutput, placeholder in zip(doutputs, self._func_graph.outputs): + if gradients_util.IsTrainable(placeholder): + if doutput is not None: + cleaned_doutputs.append(doutput) + else: + cleaned_doutputs.append(default_gradient.zeros_like(placeholder)) + # Compute the gradients using the side outputs - side_outputs = op.outputs[num_inference_outputs:] - args = list(doutputs[:num_inference_outputs]) + list(side_outputs) - return self._backward_graph_function._call_flat( # pylint: disable=protected-access - (a for a in args if a is not None), - self._backward_graph_function.captured_inputs) + return backwards_function._call_flat( # pylint: disable=protected-access + cleaned_doutputs, remapped_captures) @property def name(self): @@ -820,16 +877,190 @@ class ConcreteFunction(object): # 2. Otherwise, defun will create two functions, one for forward pass, # and the backward pass will be created via tape. # When registering the function, we register both cases. 
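The capture remapping in `_grad_fn` above is a plain substitution: backward-function captures that are internal tensors of the forward graph are swapped for the corresponding outputs of the rewritten call op, while outer captures pass through unchanged. A standalone sketch of the pattern (the function and argument names here are illustrative, not the real objects):

def _remap_backward_captures(forward_graph_outputs, op_outputs, captured_inputs):
  # Each forward-graph output now materializes as an output of the call op.
  capture_mapping = dict(zip(forward_graph_outputs, op_outputs))
  # Captures originating in the forward graph are remapped; anything else
  # (eager constants, outer-graph tensors) is passed through unchanged.
  return [capture_mapping.get(capture, capture) for capture in captured_inputs]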
- if self._backward_graph_function is None: - self._construct_backprop_function() + backward_function = self._graph_backprop_function()._inference_function forward_function = self._forward_function - backward_function = self._backward_graph_function._inference_function # pylint: enable=protected-access forward_function.add_to_graph(g) backward_function.add_to_graph(g) - def _construct_backprop_function(self): - """Constructs the backprop function object for this function.""" + def _graph_backprop_function(self, num_doutputs=None): + """A possibly-cached backprop function.""" + backward_function = self._cached_graph_backprop_functions.get( + num_doutputs, None) + if backward_function is not None: + return backward_function + backward_function = self._construct_graph_backprop_function(num_doutputs) + self._cached_graph_backprop_functions[num_doutputs] = backward_function + return backward_function + + def _construct_graph_backprop_function(self, num_doutputs=None): + """Constructs a backprop function object for this function. + + Args: + num_doutputs: The constructed backprop function will take output gradients + for the first `num_doutputs` outputs of the forward function. Defaults + to the number of outputs for the inference function, but when + higher-order gradients are computed this will increase to include side + outputs. + + Returns: + A backward function taking `num_doutputs` arguments and returning + gradients with respect to inputs of the forward function. + + self._forward_function is re-generated to account for new side outputs, if + any extra were required when building the backward pass. + """ + if num_doutputs is None: + num_doutputs = len(self._inference_function.signature.output_arg) + trainable_outputs = [ + output for output in self._func_graph.outputs[:num_doutputs] + if gradients_util.IsTrainable(output)] + + signature = [] + for t in trainable_outputs: + signature.append( + tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t))) + + def _backprop_function(*grad_ys): + return gradients_util._GradientsHelper( # pylint: disable=protected-access + trainable_outputs, + self._func_graph.inputs, + grad_ys=grad_ys, + src_graph=self._func_graph) + + with self._func_graph.as_default(): + backwards_graph = func_graph_module.FuncGraph( + _backward_name(self._func_graph.name)) + func_graph_module.func_graph_from_py_func( + name=backwards_graph.name, + python_func=_backprop_function, + args=[], kwargs={}, + signature=signature, + func_graph=backwards_graph) + backwards_graph_captures = list(backwards_graph.captures.keys()) + captures_from_forward = [ + c for c in backwards_graph_captures if + not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph] + + forward_function_name = _forward_name(self._func_graph.name) + + existing_outputs = set(self._func_graph.outputs) + for capture in captures_from_forward: + if capture not in existing_outputs: + existing_outputs.add(capture) + self._func_graph.outputs.append(capture) + backward_function_attr = _parse_func_attrs( + {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) + backward_function_attr.update(self._attrs) + + backward_function = ConcreteFunction( + backwards_graph, attrs=backward_function_attr) + forward_function_attr = _parse_func_attrs({ + BACKWARD_FUNCTION_ATTRIBUTE_NAME: + backward_function._inference_function.name}) # pylint: disable=protected-access + forward_function_attr.update(self._attrs) + + self._forward_function = _EagerDefinedFunction( + forward_function_name, self._func_graph, 
self._func_graph.inputs, + self._func_graph.outputs, forward_function_attr) + return backward_function + + def _tape_functions_for_first_order(self): + """Shortcut for when only first-order gradients are required. + + The returned backward function does not accept gradients with respect to + side output of forward_function. This is fine as long as the user can't + possibly request second order tape gradients, as when they've used a single + non-persistent GradientTape. Since we don't need the backward function to + take gradients with respect to side outputs, we can skip some potentially + slow graph building. + + Returns: + A tuple of (forward_function, backward_function): + forward_function: Takes the same inputs as the inference function, but + returns side outputs used by backward_function in addition to the + inference function's outputs. + backward_function: Takes side outputs from forward_function and + gradients with respect to the "real" outputs of forward_function and + returns gradients with respect to the inputs. + """ + if self._tape_forward_function_first_order is not None: + return (self._tape_forward_function_first_order, + self._tape_backward_function_first_order) + outputs = self._func_graph.outputs[ + :len(self._inference_function.signature.output_arg)] + forward_function, backward_function = ( + self._tape_forward_and_backward_functions(outputs)) + self._tape_forward_function_first_order = forward_function + self._tape_backward_function_first_order = backward_function + return forward_function, backward_function + + # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider + # generalizing if so. + def _tape_functions_for_higher_order(self): + """Forward and backward functions suitable for higher-order gradients. + + Unlike `_tape_functions_for_first_order`, the backward function built by + this method accepts gradients for all of the outputs of the returned forward + function, including side outputs. + + Returns: + A tuple of (forward_function, backward_function): + forward_function: Takes the same inputs as the inference function, but + returns side outputs used by backward_function in addition to the + inference function's outputs. + backward_function: Takes side outputs from forward_function and + gradients with respect to all of its outputs, real and side. Returns + gradients with respect to the inputs. + """ + if self._tape_forward_function_higher_order is not None: + return (self._tape_forward_function_higher_order, + self._tape_backward_function_higher_order) + outputs = [] + # First we need to figure out how many side outputs from the forward pass + # will be required. We do this in a temporary graph to avoid actually + # running multiple copies of the backward pass (one per _GradientsHelper + # call). + # + # While computing gradients, the backward function captures Tensors from + # the forward function. We add these as side outputs of the original + # function. However, we then need to accept output gradients with respect + # to these side outputs for higher order gradients to work. Thus we loop + # until the number of outputs of the function stabilizes. Note that this + # is only required for tape gradients, where we need to declare in advance + # all of the forward op's outputs: symbolic gradients with tf.gradients + # instead rely on regenerating backward functions when higher-order + # gradients are requested. 
+ while len(outputs) < len(self._func_graph.outputs): + new_outputs = self._func_graph.outputs[len(outputs):] + outputs = list(self._func_graph.outputs) + self._tape_forward_and_backward_functions(new_outputs) + forward_function, backward_function = ( + self._tape_forward_and_backward_functions(outputs)) + if len(self._func_graph.outputs) != len(outputs): + raise AssertionError( + ("Unexpectedly added new outputs to the forward function when " + "building the backward function: {}").format( + self._func_graph.outputs[len(outputs):])) + self._tape_forward_function_higher_order = forward_function + self._tape_backward_function_higher_order = backward_function + return forward_function, backward_function + + def _tape_forward_and_backward_functions(self, outputs): + """Constructs tape forward and back functions for `outputs`.""" + # First figure out which of `outputs` are trainable. We'll accept gradients + # for each of these in the backward function. + handles_to_variables = {self._func_graph.captures[v.handle]: v + for v in self._func_graph.variables + if v.handle in self._func_graph.captures} + trainable_outputs = [] + for output in outputs: + if gradients_util.IsTrainable(output): + # Swap in the Variable object for resource handles if we can so + # sparse gradients work. + output = handles_to_variables.get(output, output) + trainable_outputs.append(output) + backwards_graph = func_graph_module.FuncGraph( _backward_name(self._func_graph.name)) # Keep track of the forward graph so that if the backwards graph @@ -837,73 +1068,79 @@ class ConcreteFunction(object): # the forward graph. This is an edge case that can only happen with # tf.custom_gradient. backwards_graph._forward_func_graph = self._func_graph # pylint: disable=protected-access - forward_function_name = _forward_name(self._func_graph.name) - outputs = [x for x in self._func_graph.outputs - if gradients_util.IsTrainable(x)] with backwards_graph.as_default(): - gradients_wrt_outputs = [ - graph_placeholder(x.dtype, x.shape) for x in outputs - ] + gradients_wrt_outputs = [] + for output in trainable_outputs: + gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( + output) + gradients_wrt_outputs.append( + graph_placeholder(gradient_dtype, gradient_shape)) gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access - outputs, + trainable_outputs, self._func_graph.inputs, grad_ys=gradients_wrt_outputs, src_graph=self._func_graph) - backwards_graph_captures = list(backwards_graph.captures.keys()) + captures_from_forward = [ + c for c in backwards_graph.captures.keys() if + not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph] + existing_outputs = set(self._func_graph.outputs) + for capture in captures_from_forward: + if capture not in existing_outputs: + existing_outputs.add(capture) + self._func_graph.outputs.append(capture) + forward_function_name = _forward_name(self._func_graph.name) backward_function_attr = _parse_func_attrs( {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) backward_function_attr.update(self._attrs) # The ordering of `backwards_graph.inputs` is important: inputs of - # `self._backward_graph_function` correspond to outputs of - # `self._forward_function`. - backwards_graph.inputs = gradients_wrt_outputs + list( - backwards_graph.captures.values()) - # Clear captures, since we pass them in as inputs. - backwards_graph.captures = {} + # `backward_function` correspond to outputs (including + # side outputs) of `self._tape_forward_function`. 
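The gradient placeholders above rely on `default_gradient.shape_and_dtype` (added later in this patch) so that a trainable resource output gets a numeric gradient input rather than a DT_RESOURCE one; a small usage sketch, assuming eager execution:

import tensorflow as tf
from tensorflow.python.ops import default_gradient

t = tf.constant([1., 2., 3.])
v = tf.Variable([[1., 2.], [3., 4.]])

# Ordinary tensor: simply (t.shape, t.dtype).
print(default_gradient.shape_and_dtype(t))
# Resource handle: the shape and dtype of the variable it points to, so the
# backward function's gradient placeholder is float32, not a resource handle.
print(default_gradient.shape_and_dtype(v.handle))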
+ backwards_graph.inputs = ( + gradients_wrt_outputs + list(backwards_graph.captures.values())) backwards_graph.outputs.extend( grad for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True) if grad is not None) backwards_graph.structured_outputs = gradients_wrt_inputs - self._backward_graph_function = ConcreteFunction( + backward_function = ConcreteFunction( backwards_graph, attrs=backward_function_attr) forward_function_attr = _parse_func_attrs({ BACKWARD_FUNCTION_ATTRIBUTE_NAME: - self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access + backward_function._inference_function.name}) # pylint: disable=protected-access forward_function_attr.update(self._attrs) - self._forward_function = _EagerDefinedFunction( + + forward_function = _EagerDefinedFunction( forward_function_name, self._func_graph, self._func_graph.inputs, - self._func_graph.outputs + backwards_graph_captures, + self._func_graph.outputs, forward_function_attr) + return forward_function, backward_function - def _eager_backprop_call(self, args): + def _tape_backprop_call(self, args, forward_function, backward_function): """Calls the forward function and records the result on a tape. - This method fully constructs the forward and backward functions before - calling the function and recording them on the tape. - - (Only records results on a tape if the function has outputs). - Args: args: All inputs to the function, including resolved captured inputs + forward_function: The forward pass, outputting both user-specified and + side outputs. + backward_function: Computes gradients for inputs of forward_function given + output gradients for the first `N` of forward_function's outputs, not + necessarily all of them. See `_tape_functions_for_first_order` and + `_tape_functions_for_higher_order`. Returns: The call output. """ - if self._backward_graph_function is None: - self._construct_backprop_function() - ctx = context.context() self._register_gradient() with ops.get_default_graph().gradient_override_map( {"PartitionedCall": self._gradient_name, "StatefulPartitionedCall": self._gradient_name}): - outputs = self._forward_function.call(ctx, args) + outputs = forward_function.call(ctx, args) if isinstance(outputs, ops.Operation) or outputs is None: return outputs @@ -912,19 +1149,55 @@ class ConcreteFunction(object): # `side_outputs` are the intermediate Tensors that were added as outputs to # the forward graph function so that we can compute its gradient. real_outputs = outputs[:self._num_outputs] - skip_positions = [i for i, t in enumerate(real_outputs) - if not gradients_util.IsTrainable(t)] - side_outputs = outputs[self._num_outputs:] - def backward_function(*args): - args = [a for i, a in enumerate(args) - if a is not None and i not in skip_positions] - return self._backward_graph_function._call_flat( # pylint: disable=protected-access - list(args) + side_outputs, - self._backward_graph_function.captured_inputs) + capture_mapping = dict(zip(self._func_graph.outputs, outputs)) + remapped_captures = [ + capture_mapping.get(capture, capture) + for capture in backward_function.captured_inputs] + # We may need to use zeros_like to get a zero for variant Tensors with + # unconnected gradients. We do that in advance so we don't have to hold on + # to the outputs themselves, which may not be needed otherwise. 
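For orientation, the tape contract that `_tape_backprop_call` plugs into: `tape.record_operation(name, outputs, inputs, backward_fn)` registers a callable that receives one (possibly None) gradient per recorded output and must return one gradient per input. A hand-rolled sketch using the internal tape module (purely illustrative; user code would not normally do this directly):

import tensorflow as tf
from tensorflow.python.eager import backprop
from tensorflow.python.eager import tape as tape_lib

x = tf.constant(3.)
with backprop.GradientTape() as t:
  t.watch(x)
  # Compute the "forward pass" with recording paused, then record the whole
  # thing as a single operation, the way a function call is recorded above.
  with tape_lib.stop_recording():
    y = tf.square(x)

  def _backward(dy):
    return [2. * x * dy]

  tape_lib.record_operation("ManualSquare", [y], [x], _backward)

print(t.gradient(y, x))  # 2 * x = 6.0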
+ variant_zeros_like = {} + backward_function_inputs = ( + len(backward_function.inputs) - len(backward_function.captured_inputs)) + recorded_outputs = [] + trainable_recorded_outputs = 0 + skip_positions = [] + for output_index, output in enumerate(outputs): + if trainable_recorded_outputs < backward_function_inputs: + recorded_outputs.append(output) + if gradients_util.IsTrainable(output): + trainable_recorded_outputs += 1 + else: + skip_positions.append(output_index) + if output.dtype == dtypes.variant: + variant_zeros_like[output_index] = default_gradient.zeros_like(output) - tape.record_operation(self._forward_function.signature.name, real_outputs, - args, backward_function) + def _backward_function_wrapper(*args): + """Process output gradients and call the backward function.""" + processed_args = [] + input_index = 0 + for output_index, arg in enumerate(args): + if output_index in skip_positions: + continue + if arg is None: + # We're calling a (non-polymorphic) ConcreteFunction, so we need to + # have a Tensor value for each Tensor we thought would be trainable + # based on its dtype, even if it ended up being unconnected. + input_placeholder = backward_function.inputs[ + input_index] + if input_placeholder.dtype == dtypes.variant: + arg = variant_zeros_like[output_index] + else: + arg = array_ops.zeros( + *default_gradient.shape_and_dtype(input_placeholder)) + processed_args.append(arg) + input_index += 1 + return backward_function._call_flat( # pylint: disable=protected-access + processed_args, remapped_captures) + + tape.record_operation(forward_function.signature.name, + recorded_outputs, args, _backward_function_wrapper) return self._build_call_outputs(real_outputs) def _backprop_call_with_delayed_rewrite(self, args): diff --git a/tensorflow/python/eager/function_gradients_test.py b/tensorflow/python/eager/function_gradients_test.py index 98dec0b361b..09227c3a02e 100644 --- a/tensorflow/python/eager/function_gradients_test.py +++ b/tensorflow/python/eager/function_gradients_test.py @@ -40,6 +40,13 @@ from tensorflow.python.platform import test from tensorflow.python.util import nest +_COS_DERIVATIVES = [math_ops.cos, + lambda x: -math_ops.sin(x), + lambda x: -math_ops.cos(x), + math_ops.sin, + math_ops.cos] + + class FunctionGradientsTest(test.TestCase, parameterized.TestCase): def testGraphModeWithGradients(self): @@ -68,6 +75,145 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(grads.eval(), 2.0) self.assertEqual(grads.shape, v.shape) + def testSymbolicHigherOrder(self): + @def_function.function + def f(x, order): + y = def_function.function(lambda: math_ops.cos(x))() + for _ in range(order): + y, = gradients_impl.gradients(y, [x]) + return y + for order, expected in enumerate(_COS_DERIVATIVES): + self.assertAllClose( + expected(constant_op.constant(1.)), + f(constant_op.constant(1.), order)) + + @parameterized.parameters([dict(persistent=True), + dict(persistent=False)]) + def testSymbolicHigherOrderUnderTape(self, persistent): + @def_function.function + def f(x, order): + with backprop.GradientTape(persistent=persistent) as tape: + tape.watch(x) + # Note that having a tape active, even if we don't use it, forces us + # down a different function call path. Symbolic gradients should work + # here too; correctness of tape gradients are tested elsewhere. 
+ y = def_function.function(lambda: math_ops.cos(x))() + tape_dy = tape.gradient(y, x) + for _ in range(order): + y, = gradients_impl.gradients(y, [x]) + if order > 0: + y1 = tape_dy + for _ in range(order - 1): + y1, = gradients_impl.gradients(y1, [x]) + else: + y1 = y + return y, y1 + for order, expected_f in enumerate(_COS_DERIVATIVES): + expected = self.evaluate(expected_f(constant_op.constant(1.))) + self.assertAllClose( + (expected, expected), + f(constant_op.constant(1.), order)) + + def testIteratedGradientsNested(self): + + def _grad(f): + def _grad_function(primal): + with backprop.GradientTape() as tape: + tape.watch(primal) + primal_out = f(primal) + return tape.gradient(primal_out, primal) + return _grad_function + + @def_function.function + def _forward(x): + return math_ops.cos(x) + + f = _forward + traced_f = def_function.function(f) + one = constant_op.constant(1.) + for expected in _COS_DERIVATIVES: + self.assertAllClose(expected(one), f(one)) + self.assertAllClose(expected(one), traced_f(one)) + self.assertAllClose(expected(one), def_function.function(f)(one)) + f = _grad(f) + traced_f = def_function.function(_grad(traced_f)) + + def testIteratedGradientsNestedWithVariable(self): + + def _grad(f): + def _grad_function(): + with backprop.GradientTape() as tape: + primal_out = f() + g, = tape.gradient(primal_out, tape.watched_variables()) + return g + return _grad_function + + v = variables.Variable(2.) + + @def_function.function + def _forward(): + return math_ops.cos(v) + + f = _forward + + two = constant_op.constant(2.) + + for expected in _COS_DERIVATIVES: + self.assertAllClose(expected(two), f()) + self.assertAllClose(expected(two), def_function.function(f)()) + f = _grad(f) + + def testIteratedGradientsPersistent(self): + + @def_function.function + def _forward(z): + return math_ops.cos(z) + + f = _forward + with backprop.GradientTape(persistent=True) as tape: + start = constant_op.constant(1.) + tape.watch(start) + x = f(start) + for expected in _COS_DERIVATIVES: + self.assertAllClose(expected(start), x) + x = tape.gradient(x, start) + + def testHigherOrderWithVariable(self): + + v = variables.Variable(1.) + + @def_function.function + def _forward(): + return math_ops.cos(v) + + f = _forward + with backprop.GradientTape(persistent=True) as tape: + x = f() + for expected in _COS_DERIVATIVES: + self.assertAllClose(expected(constant_op.constant(1.)), x) + x, = tape.gradient(x, tape.watched_variables()) + + def testGradientsChained(self): + + @def_function.function + def _forward(z): + return math_ops.cos(z) + + f = _forward + x = constant_op.constant(1.) + with backprop.GradientTape() as t: + t.watch(x) + y = f(x) + with backprop.GradientTape() as tt: + doutputs = constant_op.constant(2.) + tt.watch(doutputs) + g = t.gradient(y, x, doutputs) + self.assertAllClose(-2. * math_ops.sin(x), g) + gg = tt.gradient(g, doutputs) + # We're taking gradients with respect to doutputs, which is just a linear + # function of the gradient. 
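Concretely, with y = cos(x) the first tape gives g = doutputs * dy/dx = -doutputs * sin(x); since g is linear in doutputs, gg = dg/d(doutputs) = -sin(x), which is what the assertion below checks (and the earlier assertion on g checks -2 * sin(x) for doutputs = 2).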
+ self.assertAllClose(-math_ops.sin(x), gg) + def testSymGradGatherNd(self): with ops.Graph().as_default(), self.cached_session() as sess: diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 0af0d94d751..40c72ef6fc4 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -148,6 +148,13 @@ void TFE_Py_TapeSetAdd(PyObject* tape); PyObject* TFE_Py_TapeSetIsEmpty(); PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors); + +// Like TFE_Py_TapeSetShouldRecord but with a ternary return: +// - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecord is false) +// - 1 if first-order gradients may be requested +// - 2 if higher-order gradients may be requested +PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors); + void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor); void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index c80cd5a29f6..4b959d9f17e 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -695,6 +695,14 @@ void SetOpAttrWithDefaults( } } +PyObject* GetPythonObjectFromInt(int num) { +#if PY_MAJOR_VERSION >= 3 + return PyLong_FromLong(num); +#else + return PyInt_FromLong(num); +#endif +} + // Python subclass of Exception that is created on not ok Status. tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED); PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; @@ -1500,33 +1508,51 @@ static std::vector MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { +// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be +// null. Returns true on success and false on a Python exception. +bool TensorShapesAndDtypes(PyObject* tensors, + std::vector* tensor_ids, + std::vector* dtypes) { + tensorflow::Safe_PyObjectPtr seq( + PySequence_Fast(tensors, "expected a sequence")); + if (seq == nullptr) { + return false; + } + int len = PySequence_Fast_GET_SIZE(seq.get()); + tensor_ids->reserve(len); + dtypes->reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); + tensor_ids->push_back(FastTensorId(item)); + dtypes->push_back(FastTensorDtype(item)); + } + return true; +} + +bool TapeCouldPossiblyRecord(PyObject* tensors) { if (tensors == Py_None) { - Py_RETURN_FALSE; + return false; } if (*ThreadTapeIsStopped()) { - Py_RETURN_FALSE; + return false; } if (!HasTape()) { + return false; + } + return true; +} + +PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { + if (!TapeCouldPossiblyRecord(tensors)) { Py_RETURN_FALSE; } - PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); - if (seq == nullptr) { - return nullptr; - } - int len = PySequence_Fast_GET_SIZE(seq); // TODO(apassos) consider not building a list and changing the API to check // each tensor individually. 
std::vector tensor_ids; std::vector dtypes; - tensor_ids.reserve(len); - dtypes.reserve(len); - for (int i = 0; i < len; ++i) { - PyObject* item = PySequence_Fast_GET_ITEM(seq, i); - tensor_ids.push_back(FastTensorId(item)); - dtypes.push_back(FastTensorDtype(item)); + if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { + return nullptr; } - Py_DECREF(seq); auto tape_set = *GetTapeSet(); for (TFE_Py_Tape* tape : tape_set) { if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { @@ -1543,6 +1569,53 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { Py_RETURN_FALSE; } +PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) { + if (!TapeCouldPossiblyRecord(tensors)) { + return GetPythonObjectFromInt(0); + } + std::vector tensor_ids; + std::vector dtypes; + if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { + return nullptr; + } + + // If there is a persistent tape watching, or if there are multiple tapes + // watching, we'll return immediately indicating that higher-order tape + // gradients are possible. + bool some_tape_watching = false; + auto tape_set = *GetTapeSet(); + for (TFE_Py_Tape* tape : tape_set) { + if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { + if (tape->tape->IsPersistent() || some_tape_watching) { + // Either this is the second tape watching, or this tape is persistent: + // higher-order gradients are possible. + return GetPythonObjectFromInt(2); + } + some_tape_watching = true; + } + } + auto forward_accumulators = *GetAccumulatorSet(); + for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { + if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) { + if (some_tape_watching) { + // This is the second tape watching: higher-order gradients are + // possible. Note that there's no equivalent of persistence for + // forward-mode. + return GetPythonObjectFromInt(2); + } + some_tape_watching = true; + } + } + if (some_tape_watching) { + // There's exactly one non-persistent tape. The user can request first-order + // gradients but won't be able to get higher-order tape gradients. + return GetPythonObjectFromInt(1); + } else { + // There are no tapes. The user can't request tape gradients. 
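A Python paraphrase of the classification above, handy for reasoning about which value a given tape configuration yields (illustrative only; the real implementation is the C++ above, which also consults forward-mode accumulators as an additional watcher):

def possible_gradient_types(tapes, tensors):
  """0 = nothing recording, 1 = first-order only, 2 = higher-order possible."""
  some_tape_watching = False
  for t in tapes:
    if t.should_record(tensors):  # hypothetical per-tape check
      if t.persistent or some_tape_watching:
        # A persistent tape, or a second tape watching: higher-order possible.
        return 2
      some_tape_watching = True
  return 1 if some_tape_watching else 0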
+ return GetPythonObjectFromInt(0); + } +} + void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { if (*ThreadTapeIsStopped()) { return; @@ -1997,14 +2070,6 @@ PyObject* GetPythonObjectFromString(const char* s) { #endif } -PyObject* GetPythonObjectFromInt(int num) { -#if PY_MAJOR_VERSION >= 3 - return PyLong_FromLong(num); -#else - return PyInt_FromLong(num); -#endif -} - bool CheckResourceVariable(PyObject* item) { if (PyObject_TypeCheck(item, resource_variable_type)) { tensorflow::Safe_PyObjectPtr handle( diff --git a/tensorflow/python/ops/default_gradient.py b/tensorflow/python/ops/default_gradient.py index 1742d92c053..1662f7e71ad 100644 --- a/tensorflow/python/ops/default_gradient.py +++ b/tensorflow/python/ops/default_gradient.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -33,3 +35,25 @@ def get_zeros_dtype(t): else: return handle_data.shape_and_type[0].dtype return t.dtype + + +def shape_and_dtype(t): + """Return the shape and dtype for the default gradient for a Tensor.""" + if t.dtype == dtypes.resource: + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if (handle_data is None or not handle_data.is_set or + len(handle_data.shape_and_type) != 1): + return tensor_shape.TensorShape(None), dtypes.float32 + else: + shape_and_type = handle_data.shape_and_type[0] + return (tensor_shape.TensorShape(shape_and_type.shape), + dtypes.as_dtype(shape_and_type.dtype)) + return t.shape, t.dtype + + +def zeros_like(t): + """Like array_ops.zeros_like, but respects resource handles.""" + if t.dtype == dtypes.resource: + return array_ops.zeros(*shape_and_dtype(t)) + else: + return array_ops.zeros_like(t) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index db39cdacbef..6f68a2c0548 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -74,6 +74,7 @@ limitations under the License. %rename("%s") TFE_Py_TapeSetIsStopped; %rename("%s") TFE_Py_TapeSetIsEmpty; %rename("%s") TFE_Py_TapeSetShouldRecord; +%rename("%s") TFE_Py_TapeSetPossibleGradientTypes; %rename("%s") TFE_Py_TapeSetDeleteTrace; %rename("%s") TFE_Py_TapeSetRecordOperation; %rename("%s") TFE_Py_TapeGradient; diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index e9d7b840e8c..69dc43f9b60 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.saved_model import function_deserialization @@ -174,6 +175,18 @@ class Loader(object): for bound_input, internal_capture in zip( bound_inputs, concrete_function.inputs[-len(bound_inputs):]): concrete_function.graph.captures[bound_input] = internal_capture + if internal_capture.dtype == dtypes.resource: + if resource_variable_ops.is_resource_variable(bound_input): + try: + handle = bound_input.handle + except ValueError: + # For mirrored variables we'll copy handle data for components + # as they get captured. 
+ pass + else: + custom_gradient.copy_handle_data(handle, internal_capture) + else: + custom_gradient.copy_handle_data(bound_input, internal_capture) # Setting "captures" first means "capture" won't create a new # placeholder for this input. concrete_function.graph.capture(bound_input) diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py index 093afbdc654..d3aaf78bc90 100644 --- a/tensorflow/python/training/tracking/tracking.py +++ b/tensorflow/python/training/tracking/tracking.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import weakref +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun from tensorflow.python.framework import dtypes @@ -172,14 +173,21 @@ class CapturableResourceDeleter(object): def __init__(self, destroy_resource_fn=None): if destroy_resource_fn: - self.destroy_resource = destroy_resource_fn + self._destroy_resource = destroy_resource_fn + self._destruction_context = ( + context.eager_mode if context.executing_eagerly() + else ops.get_default_graph().as_default) + else: + self._destroy_resource = None def destroy_resource(self): - """A function that destroys the resource.""" - pass + if self._destroy_resource: + return self._destroy_resource() def __del__(self): - self.destroy_resource() + if self._destroy_resource: + with self._destruction_context(): + self._destroy_resource() class CapturableResource(base.Trackable):
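The `CapturableResourceDeleter` change follows a general pattern worth noting: capture the execution context at construction time and re-enter it when destroying the resource, because `__del__` may run while an arbitrary (or no) graph is the default. A distilled sketch of that pattern (the class and attribute names are illustrative):

from tensorflow.python.eager import context
from tensorflow.python.framework import ops

class _ContextPreservingDeleter(object):
  """Runs a cleanup callable in the context it was created in."""

  def __init__(self, destroy_fn=None):
    self._destroy_fn = destroy_fn
    if destroy_fn:
      # Remember where we were created: eager, or a specific default graph.
      self._destruction_context = (
          context.eager_mode if context.executing_eagerly()
          else ops.get_default_graph().as_default)

  def __del__(self):
    if self._destroy_fn:
      with self._destruction_context():
        self._destroy_fn()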