Forwardprop: fix nested forwardprop of non-differentiable ops
A special case was expecting that we didn't have any tangents, when in fact we just want to discard them. PiperOrigin-RevId: 296318800 Change-Id: I7198596435f294333e00a2dcfe5ac8ec31d0b28c
This commit is contained in:
parent
9553f81edf
commit
10aff5d518
@ -230,6 +230,14 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
))
|
||||
self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
|
||||
|
||||
def testNonDifferentiableOpWithInputTangent(self):
|
||||
x = constant_op.constant(1.)
|
||||
with forwardprop.ForwardAccumulator(x, 2.) as acc1:
|
||||
with forwardprop.ForwardAccumulator(x, 2.) as acc2:
|
||||
y = array_ops.zeros_like(x)
|
||||
self.assertIsNone(acc1.jvp(y))
|
||||
self.assertIsNone(acc2.jvp(y))
|
||||
|
||||
def testJVPFunctionUsedByAccumulatorForOps(self):
|
||||
previous_fn = forwardprop._jvp_dispatch
|
||||
try:
|
||||
|
@ -981,7 +981,7 @@ class _TapeGradientFunctions(object):
|
||||
self._func_graph.outputs,
|
||||
forward_function_attr)
|
||||
|
||||
if not self._func_graph.outputs or not input_tangents:
|
||||
if not input_tangents:
|
||||
# There is no need to special-case forwardprop, so we can return the
|
||||
# forward+backward pair we've created without further wrapping.
|
||||
return (forward_function, self._func_graph, backward_function,
|
||||
@ -1085,6 +1085,11 @@ class _TapeGradientFunctions(object):
|
||||
"StatefulPartitionedCall": gradient_function}):
|
||||
forward_outputs = forward_function.call(context.context(),
|
||||
forward_inputs)
|
||||
if isinstance(forward_outputs, ops.Operation):
|
||||
# _wrapped_backward_function expects a list, but if the function has
|
||||
# no outputs its call() returns an Operation. We need to undo that
|
||||
# so we don't cause problems later.
|
||||
forward_outputs = []
|
||||
py_backward, _ = self._wrap_backward_function(
|
||||
self._func_graph, backward_function, forward_outputs)
|
||||
# We will never request backward tape gradients for this operation
|
||||
|
Loading…
x
Reference in New Issue
Block a user