Fix higher-order tape gradients of cond and case

Piggybacks on tf.function's tape interface logic: higher-order non-tape gradients already work, and tf.function takes non-tape gradients of its contents, so wrapping the cond/case op in a function routes tape gradients through that working path.

The same fix applies to While, but the test I have in mind needs another fix before it's viable. Starting small here since cond is easier.
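
For illustration (a minimal sketch, not part of this change; the function name is mine), the kind of computation this enables is an iterated tape gradient through a cond, matching the new test below:

    import tensorflow as tf

    @tf.function
    def iterated_grad(x):
      with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
          inner.watch(x)
          y = tf.cond(tf.constant(True), lambda: tf.cos(x), lambda: tf.sin(x))
        dy = inner.gradient(y, x)   # -sin(x)
      return outer.gradient(dy, x)  # -cos(x), the second-order tape gradient

    print(iterated_grad(tf.constant(1.)))  # approx. -0.5403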

PiperOrigin-RevId: 337924627
Change-Id: Ife7e05a2c0818f6310c4cad19ec4fc46c8382000
Allen Lavoie authored on 2020-10-19 13:45:17 -07:00, committed by TensorFlower Gardener
commit 3ce466a482 (parent 14c3950397)
4 changed files with 108 additions and 30 deletions


@@ -1418,13 +1418,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
             num_output_tangents)
 
-
-# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
-# unfortunately too slow to use here.
-_POSSIBLE_GRADIENT_TYPES_NONE = 0
-_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
-_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
-
 
 class _ForwardBackwardCall(object):
   """Holds the state of a function call between execution and recording."""
@@ -1918,9 +1911,8 @@ class ConcreteFunction(object):
                          "on invocation of %s, the %d-th input (%s) was not a "
                          "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
-    possible_gradient_type = (
-        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
-    if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
+    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
+    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
         and executing_eagerly):
       # No tape is watching; skip to running the function.
       return self._build_call_outputs(self._inference_function.call(
@@ -2080,7 +2072,7 @@ class ConcreteFunction(object):
     Args:
       args: A flat list of Tensors with all of the inputs to the forward
         function (including user-specified and captured inputs).
-      possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*.
+      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
       executing_eagerly: Boolean, the value of context.executing_eagerly().
 
     Returns:
@@ -2098,7 +2090,8 @@ class ConcreteFunction(object):
     # Allows re-use of forward and backward function pairs depending on the
     # tapes and forward accumulators watching its inputs.
     cache_key = (need_gradients_for_jvps, input_tangents.indices)
-    if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER:
+    if (possible_gradient_type
+        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
       if input_tangents.indices or executing_eagerly:
         # There is a single non-persistent tape active, so the user can only
         # request first-order gradients from a tape. We can spend less time
@@ -2129,7 +2122,8 @@ class ConcreteFunction(object):
         return _ForwardBackwardCall(
             self._delayed_rewrite_functions, args, input_tangents.tangents,
             tape_watching=True)
-    elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER:
+    elif (possible_gradient_type
+          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
       # Either there's a persistent tape watching, or there are multiple nested
       # tapes. Either way, the user may request higher-order gradients. We'll
       # spend a bit more time and make sure higher-order gradients are correct.
@@ -2144,7 +2138,7 @@ class ConcreteFunction(object):
         self._higher_order_tape_functions[cache_key] = functions
       return _ForwardBackwardCall(functions, args, input_tangents.tangents,
                                   tape_watching=True)
-    # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no
+    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
     # tape is recording.
     return _ForwardBackwardCall(
         self._delayed_rewrite_functions, args, input_tangents.tangents,


@@ -960,6 +960,42 @@ class CondV2Test(test.TestCase):
     self.assertAllEqual(fn_with_cond(), 12.0)
 
+  def _CheckIteratedCosGradients(self, func):
+
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.gradient(primal_out, primal)
+      return _grad_function
+
+    f = func
+    one = constant_op.constant(1.)
+    for expected in [math_ops.cos,
+                     lambda x: -math_ops.sin(x),
+                     lambda x: -math_ops.cos(x),
+                     math_ops.sin,
+                     math_ops.cos]:
+      self.assertAllClose(expected(one), def_function.function(f)(one))
+      f = _grad(f)
+
+  def testIteratedGradientsCond(self):
+
+    def _func(x):
+      return cond_v2.cond_v2(
+          constant_op.constant(True),
+          lambda: math_ops.cos(array_ops.identity(x)),
+          lambda: math_ops.sin(array_ops.identity(x)))
+
+    self._CheckIteratedCosGradients(_func)
+
+  def testIteratedGradientsCase(self):
+
+    def _func(x):
+      return cond_v2.indexed_case(
+          constant_op.constant(1),
+          [lambda: math_ops.sin(array_ops.identity(x)),
+           lambda: math_ops.cos(array_ops.identity(x))])
+
+    self._CheckIteratedCosGradients(_func)
+
   def testLowering(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g) as sess:


@@ -26,6 +26,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   return [None] + outputs
 
 
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+  # GradientTapes created inside a function currently don't work well with
+  # un-wrapped control flow ops in that same function. Wrapping in an extra
+  # layer of intermediate function means we run extra logic in the function
+  # gradient code to record the correct intermediates on the tape.
+  #
+  # The function attribute inputs to cond/case ops are not hashable, so we pass
+  # everything as a capture to bypass defun's caching.
+  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+      # We only need one function between the tape and the cond; if we've
+      # already wrapped once, we stop wrapping to avoid infinite recursion.
+      and not (ops.get_default_graph().building_function
+               and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+    op = None
+
+    def _run_make_and_extract_op():
+      # Post-processing happens on the cond op, not the function call op.
+      nonlocal op
+      tensors = make_op()
+      op, tensors = _get_op_and_outputs(tensors)  # pylint: disable=unused-variable
+      return tensors
+
+    return op, function.defun_with_attributes(
+        _run_make_and_extract_op,
+        attributes=dict(func_name="cond_gradient_wrapper"))()
+  else:
+    return _get_op_and_outputs(make_op())
+
+
 def _build_cond(pred,
                 true_graph,
                 false_graph,
@@ -268,16 +300,17 @@ def _build_cond(pred,
     else:
       op_fn = gen_functional_ops.stateless_if
 
-    tensors = op_fn(
-        pred,
-        cond_inputs, [t.dtype for t in true_graph.outputs],
-        util.create_new_tf_function(true_graph),
-        util.create_new_tf_function(false_graph),
-        output_shapes=_get_output_shapes(true_graph.outputs,
-                                         false_graph.outputs),
-        name=name)
+    def make_op():
+      return op_fn(
+          pred,
+          cond_inputs, [t.dtype for t in true_graph.outputs],
+          util.create_new_tf_function(true_graph),
+          util.create_new_tf_function(false_graph),
+          output_shapes=_get_output_shapes(true_graph.outputs,
+                                           false_graph.outputs),
+          name=name)
+    if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
 
-  if_op, tensors = _get_op_and_outputs(tensors)
   # `if_op` is None if this is a `StatelessIf` op with no outputs.
   if if_op is not None:
     if_op._true_graph = true_graph
@@ -1156,14 +1189,16 @@ def _build_case(branch_index,
   # Create the Case op.
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
-    tensors = op_fn(
-        branch_index,
-        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
-        [util.create_new_tf_function(g) for g in branch_graphs],
-        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
-        name=name)
-  case_op, tensors = _get_op_and_outputs(tensors)
+
+    def _make_op():
+      return op_fn(
+          branch_index,
+          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+          [util.create_new_tf_function(g) for g in branch_graphs],
+          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+          name=name)
+    case_op, tensors = _run_as_function_for_tape_gradients(
+        _make_op, case_inputs)
 
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
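
For intuition, a rough standalone sketch of what the wrapper buys (illustrative only; it reuses the internal defun_with_attributes API that the diff itself uses, and the helper names are mine, not from this commit). Building the cond inside a one-off function means an enclosing tape records a function call, whose gradient machinery handles the higher-order intermediates, instead of recording the raw If/Case op:

    from tensorflow.python.eager import backprop
    from tensorflow.python.eager import function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import cond_v2
    from tensorflow.python.ops import math_ops

    def _build_wrapped_cond(x):
      # Everything is a capture (no arguments), mirroring the comment above
      # about cond/case function attributes not being hashable.
      def _make_op():
        return cond_v2.cond_v2(constant_op.constant(True),
                               lambda: math_ops.cos(x),
                               lambda: math_ops.sin(x))
      return function.defun_with_attributes(
          _make_op, attributes=dict(func_name="cond_gradient_wrapper"))()

    x = constant_op.constant(1.)
    with backprop.GradientTape() as outer:
      outer.watch(x)
      with backprop.GradientTape() as inner:
        inner.watch(x)
        y = _build_wrapped_cond(x)
      dy = inner.gradient(y, x)   # -sin(x)
    d2y = outer.gradient(dy, x)   # -cos(x): the higher-order case being fixed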


@@ -24,6 +24,7 @@ import contextlib
 from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tfe
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
@@ -1007,3 +1008,15 @@ def _AggregatedGrads(grads,
         # out_grads[i] is [], thus its aggregation is simply None.
         out_grads[i] = None
   return out_grads
+
+
+# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
+# unfortunately too slow to use here.
+POSSIBLE_GRADIENT_TYPES_NONE = 0
+POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
+POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
+
+
+def PossibleTapeGradientTypes(tensors):
+  """Determines whether and how `args` may require tape gradients."""
+  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
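
As a rough illustration of how these values relate to active tapes (behavior inferred from the comments in the ConcreteFunction hunks above, not stated in this diff; run eagerly):

    from tensorflow.python.eager import backprop
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import gradients_util

    x = constant_op.constant(3.)
    gradients_util.PossibleTapeGradientTypes([x])  # NONE: no tape is watching x

    with backprop.GradientTape() as tape:
      tape.watch(x)
      # A single non-persistent tape watching x: FIRST_ORDER
      gradients_util.PossibleTapeGradientTypes([x])

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(x)
      # A persistent tape (or nested tapes) watching x: HIGHER_ORDER
      gradients_util.PossibleTapeGradientTypes([x])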