Fix higher-order tape gradients of cond and case

Piggybacks on the tf.function tape interface logic: higher-order non-tape
gradients work, and tf.function takes non-tape gradients of its contents.

The same fix applies to While, but the test I have in mind needs another fix
before it's viable. Starting small here since cond is easier.

PiperOrigin-RevId: 337924627
Change-Id: Ife7e05a2c0818f6310c4cad19ec4fc46c8382000
This commit is contained in: parent 14c3950397, commit 3ce466a482
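For context, a minimal user-level sketch of the scenario this change targets: a GradientTape nested inside another tape, both created inside a tf.function, differentiating twice through an un-wrapped cond. The names below are illustrative and not part of the change itself.

import tensorflow as tf

@tf.function
def second_derivative_of_cond(x):
  # Both tapes live inside the tf.function and the cond is not wrapped in
  # anything else; this is the case the higher-order tape gradient fix covers.
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = tf.cond(tf.constant(True),
                  lambda: tf.cos(x),
                  lambda: tf.sin(x))        # takes the true branch: cos(x)
    dy_dx = inner.gradient(y, x)            # -sin(x)
  return outer.gradient(dy_dx, x)           # -cos(x)

print(second_derivative_of_cond(tf.constant(1.0)))  # approximately -0.5403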
@@ -1418,13 +1418,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
                                          num_output_tangents)
 
 
-# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
-# unfortunately too slow to use here.
-_POSSIBLE_GRADIENT_TYPES_NONE = 0
-_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
-_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
-
-
 class _ForwardBackwardCall(object):
   """Holds the state of a function call between execution and recording."""
 
@@ -1918,9 +1911,8 @@ class ConcreteFunction(object):
             "on invocation of %s, the %d-th input (%s) was not a "
             "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
-    possible_gradient_type = (
-        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
-    if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
+    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
+    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
         and executing_eagerly):
       # No tape is watching; skip to running the function.
       return self._build_call_outputs(self._inference_function.call(
@@ -2080,7 +2072,7 @@ class ConcreteFunction(object):
     Args:
       args: A flat list of Tensors with all of the inputs to the forward
         function (including user-specified and captured inputs).
-      possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*.
+      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
       executing_eagerly: Boolean, the value of context.executing_eagerly().
 
     Returns:
@@ -2098,7 +2090,8 @@ class ConcreteFunction(object):
     # Allows re-use of forward and backward function pairs depending on the
     # tapes and forward accumulators watching its inputs.
     cache_key = (need_gradients_for_jvps, input_tangents.indices)
-    if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER:
+    if (possible_gradient_type
+        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
       if input_tangents.indices or executing_eagerly:
         # There is a single non-persistent tape active, so the user can only
         # request first-order gradients from a tape. We can spend less time
@@ -2129,7 +2122,8 @@ class ConcreteFunction(object):
         return _ForwardBackwardCall(
             self._delayed_rewrite_functions, args, input_tangents.tangents,
             tape_watching=True)
-    elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER:
+    elif (possible_gradient_type
+          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
      # Either there's a persistent tape watching, or there are multiple nested
      # tapes. Either way, the user may request higher-order gradients. We'll
      # spend a bit more time and make sure higher-order gradients are correct.
@@ -2144,7 +2138,7 @@ class ConcreteFunction(object):
        self._higher_order_tape_functions[cache_key] = functions
      return _ForwardBackwardCall(functions, args, input_tangents.tangents,
                                  tape_watching=True)
-    # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no
+    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
     # tape is recording.
     return _ForwardBackwardCall(
         self._delayed_rewrite_functions, args, input_tangents.tangents,
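The dispatch above keys off how tapes are arranged when the ConcreteFunction is called. A rough user-level sketch of the three levels follows; the constant names in the comments refer to the values this change moves into gradients_util, while the snippet itself is only illustrative.

import tensorflow as tf

x = tf.constant(2.0)

# No tape active: nothing needs to be recorded
# (POSSIBLE_GRADIENT_TYPES_NONE when executing eagerly).
y = tf.square(x)

# One non-persistent tape: only a first-order gradient can be requested
# (POSSIBLE_GRADIENT_TYPES_FIRST_ORDER).
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.square(x)
print(tape.gradient(y, x))  # 4.0

# A persistent tape (or nested tapes): higher-order gradients are possible
# (POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER).
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y = tf.square(x)
  dy = tape.gradient(y, x)
print(tape.gradient(dy, x))  # 2.0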
@@ -960,6 +960,42 @@ class CondV2Test(test.TestCase):
 
     self.assertAllEqual(fn_with_cond(), 12.0)
 
+  def _CheckIteratedCosGradients(self, func):
+
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.gradient(primal_out, primal)
+      return _grad_function
+
+    f = func
+    one = constant_op.constant(1.)
+    for expected in [math_ops.cos,
+                     lambda x: -math_ops.sin(x),
+                     lambda x: -math_ops.cos(x),
+                     math_ops.sin,
+                     math_ops.cos]:
+      self.assertAllClose(expected(one), def_function.function(f)(one))
+      f = _grad(f)
+
+  def testIteratedGradientsCond(self):
+    def _func(x):
+      return cond_v2.cond_v2(
+          constant_op.constant(True),
+          lambda: math_ops.cos(array_ops.identity(x)),
+          lambda: math_ops.sin(array_ops.identity(x)))
+    self._CheckIteratedCosGradients(_func)
+
+  def testIteratedGradientsCase(self):
+    def _func(x):
+      return cond_v2.indexed_case(
+          constant_op.constant(1),
+          [lambda: math_ops.sin(array_ops.identity(x)),
+           lambda: math_ops.cos(array_ops.identity(x))])
+    self._CheckIteratedCosGradients(_func)
+
   def testLowering(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g) as sess:
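The test helper leans on the four-step derivative cycle of cosine (cos -> -sin -> -cos -> sin -> cos). A standalone sketch of the same cycle using plain ops and the public tf.* API (rather than the internal test modules), for reference:

import tensorflow as tf

def grad(f):
  def grad_fn(x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f(x)
    return tape.gradient(y, x)
  return grad_fn

one = tf.constant(1.0)
f = tf.cos
for expected in (tf.cos, lambda x: -tf.sin(x), lambda x: -tf.cos(x),
                 tf.sin, tf.cos):
  tf.debugging.assert_near(expected(one), f(one))
  f = grad(f)  # differentiate once more for the next iteration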
@@ -26,6 +26,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   return [None] + outputs
 
 
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+  # GradientTapes created inside a function currently don't work well with
+  # un-wrapped control flow ops in that same function. Wrapping in an extra
+  # layer of intermediate function means we run extra logic in the function
+  # gradient code to record the correct intermediates on the tape.
+  #
+  # The function attribute inputs to cond/case ops are not hashable, so we pass
+  # everything as a capture to bypass defun's caching.
+  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+      # We only need one function between the tape and the cond; if we've
+      # already wrapped once, we stop wrapping to avoid infinite recursion.
+      and not (ops.get_default_graph().building_function
+               and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+
+    op = None
+    def _run_make_and_extract_op():
+      # Post-processing happens on the cond op, not the function call op.
+      nonlocal op
+      tensors = make_op()
+      op, tensors = _get_op_and_outputs(tensors)  # pylint: disable=unused-variable
+      return tensors
+
+    return op, function.defun_with_attributes(
+        _run_make_and_extract_op,
+        attributes=dict(func_name="cond_gradient_wrapper"))()
+  else:
+    return _get_op_and_outputs(make_op())
+
+
 def _build_cond(pred,
                 true_graph,
                 false_graph,
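The `op = None` / `nonlocal op` dance above exists because the wrapped callable may only return tensors (they become the wrapper function's outputs), while the caller still needs the created op for post-processing. A tiny, TensorFlow-free sketch of that pattern, with illustrative names:

def run_and_extract(make_values):
  """Runs `make_values` inside a wrapper while exposing a side result."""
  side = None

  def wrapper():
    nonlocal side
    values = make_values()
    # Peel off the part the wrapper must not return; keep it for the caller.
    side, rest = values[0], values[1:]
    return rest

  outputs = wrapper()  # in the real code, wrapper() runs inside a defun
  return side, outputs

op, outs = run_and_extract(lambda: ("the-op", 1.0, 2.0))
assert op == "the-op" and outs == (1.0, 2.0)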
@@ -268,16 +300,17 @@ def _build_cond(pred,
     else:
       op_fn = gen_functional_ops.stateless_if
 
-    tensors = op_fn(
-        pred,
-        cond_inputs, [t.dtype for t in true_graph.outputs],
-        util.create_new_tf_function(true_graph),
-        util.create_new_tf_function(false_graph),
-        output_shapes=_get_output_shapes(true_graph.outputs,
-                                         false_graph.outputs),
-        name=name)
+    def make_op():
+      return op_fn(
+          pred,
+          cond_inputs, [t.dtype for t in true_graph.outputs],
+          util.create_new_tf_function(true_graph),
+          util.create_new_tf_function(false_graph),
+          output_shapes=_get_output_shapes(true_graph.outputs,
+                                           false_graph.outputs),
+          name=name)
+    if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
 
-  if_op, tensors = _get_op_and_outputs(tensors)
   # `if_op` is None if this is a `StatelessIf` op with no outputs.
   if if_op is not None:
     if_op._true_graph = true_graph
@@ -1156,14 +1189,16 @@ def _build_case(branch_index,
   # Create the Case op.
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
-    tensors = op_fn(
-        branch_index,
-        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
-        [util.create_new_tf_function(g) for g in branch_graphs],
-        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
-        name=name)
-
-  case_op, tensors = _get_op_and_outputs(tensors)
+    def _make_op():
+      return op_fn(
+          branch_index,
+          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+          [util.create_new_tf_function(g) for g in branch_graphs],
+          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+          name=name)
+    case_op, tensors = _run_as_function_for_tape_gradients(
+        _make_op, case_inputs)
+
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
 
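The case path gets the same treatment as cond. A user-level sketch of a second-order gradient through the public tf.switch_case API, mirroring the new testIteratedGradientsCase (names are illustrative):

import tensorflow as tf

@tf.function
def second_derivative_of_case(x):
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = tf.switch_case(tf.constant(1),
                         [lambda: tf.sin(x),
                          lambda: tf.cos(x)])  # branch 1: cos(x)
    dy_dx = inner.gradient(y, x)   # -sin(x)
  return outer.gradient(dy_dx, x)  # -cos(x)

print(second_derivative_of_case(tf.constant(1.0)))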
@@ -24,6 +24,7 @@ import contextlib
 from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tfe
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
@@ -1007,3 +1008,15 @@ def _AggregatedGrads(grads,
       # out_grads[i] is [], thus its aggregation is simply None.
       out_grads[i] = None
   return out_grads
+
+
+# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
+# unfortunately too slow to use here.
+POSSIBLE_GRADIENT_TYPES_NONE = 0
+POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
+POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
+
+
+def PossibleTapeGradientTypes(tensors):
+  """Determines whether and how `args` may require tape gradients."""
+  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
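The new helper is a thin named wrapper over the pybind call; internal callers such as the ConcreteFunction dispatch above compare its result against the module-level constants. A hedged sketch of that usage, assuming the internal modules are importable the same way the test file imports them:

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gradients_util

x = constant_op.constant(3.0)
with backprop.GradientTape(persistent=True) as tape:
  tape.watch(x)
  # A persistent tape watching `x` means higher-order gradients may be
  # requested, per the constants defined above.
  grad_type = gradients_util.PossibleTapeGradientTypes([x])

assert grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER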