Forwardprop: fix nested forwardprop of non-differentiable ops
A special case was expecting that we didn't have any tangents, when in fact we just want to discard them.

PiperOrigin-RevId: 296318800
Change-Id: I7198596435f294333e00a2dcfe5ac8ec31d0b28c
Commit 10aff5d518 (parent 9553f81edf)
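For illustration, a minimal standalone sketch of the behavior the new test covers, written against the public API on the assumption that `tf.autodiff.ForwardAccumulator` mirrors the internal `forwardprop.ForwardAccumulator` used in the test below:

import tensorflow as tf

x = tf.constant(1.)
# Nest two accumulators watching the same primal with tangent 2.
with tf.autodiff.ForwardAccumulator(x, 2.) as outer:
  with tf.autodiff.ForwardAccumulator(x, 2.) as inner:
    # zeros_like is non-differentiable, so its output carries no tangent;
    # the nested accumulators should simply discard it rather than error.
    y = tf.zeros_like(x)

# Neither accumulator reports a JVP for the non-differentiable result.
print(outer.jvp(y), inner.jvp(y))  # None None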
@@ -230,6 +230,14 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
         ))
     self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
 
+  def testNonDifferentiableOpWithInputTangent(self):
+    x = constant_op.constant(1.)
+    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
+      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
+        y = array_ops.zeros_like(x)
+      self.assertIsNone(acc1.jvp(y))
+    self.assertIsNone(acc2.jvp(y))
+
   def testJVPFunctionUsedByAccumulatorForOps(self):
     previous_fn = forwardprop._jvp_dispatch
     try:
@@ -981,7 +981,7 @@ class _TapeGradientFunctions(object):
         self._func_graph.outputs,
         forward_function_attr)
 
-    if not self._func_graph.outputs or not input_tangents:
+    if not input_tangents:
       # There is no need to special-case forwardprop, so we can return the
      # forward+backward pair we've created without further wrapping.
      return (forward_function, self._func_graph, backward_function,
@@ -1085,6 +1085,11 @@ class _TapeGradientFunctions(object):
            "StatefulPartitionedCall": gradient_function}):
         forward_outputs = forward_function.call(context.context(),
                                                 forward_inputs)
+        if isinstance(forward_outputs, ops.Operation):
+          # _wrapped_backward_function expects a list, but if the function has
+          # no outputs its call() returns an Operation. We need to undo that
+          # so we don't cause problems later.
+          forward_outputs = []
       py_backward, _ = self._wrap_backward_function(
           self._func_graph, backward_function, forward_outputs)
       # We will never request backward tape gradients for this operation
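As a side note, a minimal sketch of the normalization the last hunk performs inline, assuming only that a no-output function call may return an `ops.Operation` instead of a list of tensors (the helper name `_normalize_forward_outputs` is hypothetical):

from tensorflow.python.framework import ops


def _normalize_forward_outputs(maybe_outputs):
  """Hypothetical helper mirroring the inline fix above."""
  if isinstance(maybe_outputs, ops.Operation):
    # A function with no outputs returns an Operation from call();
    # _wrapped_backward_function expects a (possibly empty) list of tensors.
    return []
  return maybe_outputs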