diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ab32c8370af..9d9cf0b50c3 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1418,13 +1418,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
                                              num_output_tangents)
 
 
-# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
-# unfortunately too slow to use here.
-_POSSIBLE_GRADIENT_TYPES_NONE = 0
-_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
-_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
-
-
 class _ForwardBackwardCall(object):
   """Holds the state of a function call between execution and recording."""
 
@@ -1918,9 +1911,8 @@ class ConcreteFunction(object):
                          "on invocation of %s, the %d-th input (%s) was not a "
                          "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
-    possible_gradient_type = (
-        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
-    if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
+    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
+    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
         and executing_eagerly):
       # No tape is watching; skip to running the function.
       return self._build_call_outputs(self._inference_function.call(
@@ -2080,7 +2072,7 @@
     Args:
      args: A flat list of Tensors with all of the inputs to the forward
        function (including user-specified and captured inputs).
-      possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*.
+      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
      executing_eagerly: Boolean, the value of context.executing_eagerly().
 
     Returns:
@@ -2098,7 +2090,8 @@
     # Allows re-use of forward and backward function pairs depending on the
     # tapes and forward accumulators watching its inputs.
     cache_key = (need_gradients_for_jvps, input_tangents.indices)
-    if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER:
+    if (possible_gradient_type
+        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
       if input_tangents.indices or executing_eagerly:
         # There is a single non-persistent tape active, so the user can only
         # request first-order gradients from a tape. We can spend less time
@@ -2129,7 +2122,8 @@
         return _ForwardBackwardCall(
             self._delayed_rewrite_functions, args, input_tangents.tangents,
             tape_watching=True)
-    elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER:
+    elif (possible_gradient_type
+          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
       # Either there's a persistent tape watching, or there are multiple nested
       # tapes. Either way, the user may request higher-order gradients. We'll
       # spend a bit more time and make sure higher-order gradients are correct.
@@ -2144,7 +2138,7 @@
         self._higher_order_tape_functions[cache_key] = functions
       return _ForwardBackwardCall(functions, args, input_tangents.tangents,
                                   tape_watching=True)
-    # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no
+    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
     # tape is recording.
     return _ForwardBackwardCall(
         self._delayed_rewrite_functions, args, input_tangents.tangents,
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 70d7b2530a9..fb60dc23e5e 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -960,6 +960,42 @@ class CondV2Test(test.TestCase):
 
     self.assertAllEqual(fn_with_cond(), 12.0)
 
+  def _CheckIteratedCosGradients(self, func):
+
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.gradient(primal_out, primal)
+      return _grad_function
+
+    f = func
+    one = constant_op.constant(1.)
+    for expected in [math_ops.cos,
+                     lambda x: -math_ops.sin(x),
+                     lambda x: -math_ops.cos(x),
+                     math_ops.sin,
+                     math_ops.cos]:
+      self.assertAllClose(expected(one), def_function.function(f)(one))
+      f = _grad(f)
+
+  def testIteratedGradientsCond(self):
+    def _func(x):
+      return cond_v2.cond_v2(
+          constant_op.constant(True),
+          lambda: math_ops.cos(array_ops.identity(x)),
+          lambda: math_ops.sin(array_ops.identity(x)))
+    self._CheckIteratedCosGradients(_func)
+
+  def testIteratedGradientsCase(self):
+    def _func(x):
+      return cond_v2.indexed_case(
+          constant_op.constant(1),
+          [lambda: math_ops.sin(array_ops.identity(x)),
+           lambda: math_ops.cos(array_ops.identity(x))])
+    self._CheckIteratedCosGradients(_func)
+
   def testLowering(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g) as sess:
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 5bdd2494e91..75130fcd8a7 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -26,6 +26,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   return [None] + outputs
 
 
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+  # GradientTapes created inside a function currently don't work well with
+  # un-wrapped control flow ops in that same function. Wrapping in an extra
+  # layer of intermediate function means we run extra logic in the function
+  # gradient code to record the correct intermediates on the tape.
+  #
+  # The function attribute inputs to cond/case ops are not hashable, so we pass
+  # everything as a capture to bypass defun's caching.
+  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+      # We only need one function between the tape and the cond; if we've
+      # already wrapped once, we stop wrapping to avoid infinite recursion.
+      and not (ops.get_default_graph().building_function
+               and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+
+    op = None
+    def _run_make_and_extract_op():
+      # Post-processing happens on the cond op, not the function call op.
+      nonlocal op
+      tensors = make_op()
+      op, tensors = _get_op_and_outputs(tensors)  # pylint: disable=unused-variable
+      return tensors
+
+    return op, function.defun_with_attributes(
+        _run_make_and_extract_op,
+        attributes=dict(func_name="cond_gradient_wrapper"))()
+  else:
+    return _get_op_and_outputs(make_op())
+
+
 def _build_cond(pred,
                 true_graph,
                 false_graph,
@@ -268,16 +300,17 @@
     else:
       op_fn = gen_functional_ops.stateless_if
 
-    tensors = op_fn(
-        pred,
-        cond_inputs, [t.dtype for t in true_graph.outputs],
-        util.create_new_tf_function(true_graph),
-        util.create_new_tf_function(false_graph),
-        output_shapes=_get_output_shapes(true_graph.outputs,
-                                         false_graph.outputs),
-        name=name)
+    def make_op():
+      return op_fn(
+          pred,
+          cond_inputs, [t.dtype for t in true_graph.outputs],
+          util.create_new_tf_function(true_graph),
+          util.create_new_tf_function(false_graph),
+          output_shapes=_get_output_shapes(true_graph.outputs,
+                                           false_graph.outputs),
+          name=name)
+    if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
 
-  if_op, tensors = _get_op_and_outputs(tensors)
   # `if_op` is None if this is a `StatelessIf` op with no outputs.
   if if_op is not None:
     if_op._true_graph = true_graph
@@ -1156,14 +1189,16 @@
   # Create the Case op.
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
-    tensors = op_fn(
-        branch_index,
-        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
-        [util.create_new_tf_function(g) for g in branch_graphs],
-        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
-        name=name)
 
-  case_op, tensors = _get_op_and_outputs(tensors)
+    def _make_op():
+      return op_fn(
+          branch_index,
+          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+          [util.create_new_tf_function(g) for g in branch_graphs],
+          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+          name=name)
+    case_op, tensors = _run_as_function_for_tape_gradients(
+        _make_op, case_inputs)
 
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 4d4df0ffa48..c356e82ac1f 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -24,6 +24,7 @@ import contextlib
 from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tfe
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
@@ -1007,3 +1008,15 @@
       # out_grads[i] is [], thus its aggregation is simply None.
       out_grads[i] = None
   return out_grads
+
+
+# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
+# unfortunately too slow to use here.
+POSSIBLE_GRADIENT_TYPES_NONE = 0
+POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
+POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
+
+
+def PossibleTapeGradientTypes(tensors):
+  """Determines whether and how `tensors` may require tape gradients."""
+  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
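Note (not part of the patch): a minimal end-to-end sketch of the behavior this change enables, mirroring the new cond_v2_test.py tests but written against the public TF 2.x API. The helper names `f` and `grad` are illustrative only, and the snippet assumes a build that includes this fix.

    import tensorflow as tf

    def f(x):
      # Inside a tf.function, tf.cond builds the cond_v2 If/StatelessIf op touched above.
      return tf.cond(tf.constant(True), lambda: tf.cos(x), lambda: tf.sin(x))

    def grad(g):
      # Creates a GradientTape inside the traced function, like the test's _grad helper.
      def grad_fn(x):
        with tf.GradientTape() as tape:
          tape.watch(x)
          y = g(x)
        return tape.gradient(y, x)
      return grad_fn

    one = tf.constant(1.)
    print(tf.function(f)(one))              # cos(1)
    print(tf.function(grad(f))(one))        # -sin(1)
    print(tf.function(grad(grad(f)))(one))  # -cos(1): second-order tape gradient through the cond

When the cond's inputs may be watched by nested (or persistent) tapes, gradients_util.PossibleTapeGradientTypes reports POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER and _build_cond routes op creation through _run_as_function_for_tape_gradients, so the extra "cond_gradient_wrapper" function layer lets the function-gradient code record the intermediates the tape needs for higher-order gradients.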