Forwardprop: fix nested forwardprop of non-differentiable ops

A special case assumed there were no input tangents, when in fact we just want to discard them.

PiperOrigin-RevId: 296318800
Change-Id: I7198596435f294333e00a2dcfe5ac8ec31d0b28c
Allen Lavoie 2020-02-20 16:02:32 -08:00 committed by TensorFlower Gardener
parent 9553f81edf
commit 10aff5d518
2 changed files with 14 additions and 1 deletion
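The failure mode fixed here is hit by nesting forward-mode accumulators over an op that produces no tangent. A rough public-API analogue of the new test, assuming a TF 2.x build where tf.autodiff.ForwardAccumulator is available (the test itself exercises the internal forwardprop module):

    import tensorflow as tf

    x = tf.constant(1.)
    with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(2.)) as acc1:
      with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(2.)) as acc2:
        # zeros_like consumes a tensor that carries a tangent but produces no
        # tangent of its own; this nested case used to trip a special case.
        y = tf.zeros_like(x)
    print(acc1.jvp(y))  # None: the input tangent is simply discarded
    print(acc2.jvp(y))  # None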


@@ -230,6 +230,14 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
         ))
     self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
 
+  def testNonDifferentiableOpWithInputTangent(self):
+    x = constant_op.constant(1.)
+    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
+      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
+        y = array_ops.zeros_like(x)
+    self.assertIsNone(acc1.jvp(y))
+    self.assertIsNone(acc2.jvp(y))
+
   def testJVPFunctionUsedByAccumulatorForOps(self):
     previous_fn = forwardprop._jvp_dispatch
     try:


@@ -981,7 +981,7 @@ class _TapeGradientFunctions(object):
         self._func_graph.outputs,
         forward_function_attr)
-    if not self._func_graph.outputs or not input_tangents:
+    if not input_tangents:
       # There is no need to special-case forwardprop, so we can return the
       # forward+backward pair we've created without further wrapping.
       return (forward_function, self._func_graph, backward_function,
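The dropped self._func_graph.outputs check is the substance of the fix: a function with no outputs can still receive input tangents, and those tangents should flow into the forwardprop wrapping (where they are discarded) rather than short-circuiting it. A schematic sketch of the intended branching, with invented helper names standing in for the real machinery:

    # Sketch only; build_forward_backward and wrap_for_forwardprop are
    # hypothetical names, not functions from this file.
    def _build(func_graph, input_tangents):
      forward, backward = build_forward_backward(func_graph)
      if not input_tangents:
        # No accumulator is watching the inputs; the plain pair is enough.
        return forward, backward
      # Even when func_graph has no outputs, input tangents exist and must be
      # consumed (and then dropped) by the forwardprop wrapper, not skipped.
      return wrap_for_forwardprop(forward, backward, input_tangents)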
@@ -1085,6 +1085,11 @@ class _TapeGradientFunctions(object):
          "StatefulPartitionedCall": gradient_function}):
       forward_outputs = forward_function.call(context.context(),
                                               forward_inputs)
+      if isinstance(forward_outputs, ops.Operation):
+        # _wrapped_backward_function expects a list, but if the function has
+        # no outputs its call() returns an Operation. We need to undo that
+        # so we don't cause problems later.
+        forward_outputs = []
     py_backward, _ = self._wrap_backward_function(
         self._func_graph, backward_function, forward_outputs)
     # We will never request backward tape gradients for this operation
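The new isinstance guard covers the case where the forward function has no outputs: its call() then hands back an ops.Operation, while _wrapped_backward_function expects a list of output tensors. The normalization in isolation, as a minimal sketch (run_forward is a placeholder standing in for forward_function.call, not a real API):

    from tensorflow.python.framework import ops

    forward_outputs = run_forward()  # stands in for forward_function.call(...)
    if isinstance(forward_outputs, ops.Operation):
      # A call with no outputs returns the Operation itself; substitute an
      # empty list so downstream code can treat it like a list of tensors.
      forward_outputs = []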