From 10aff5d518393c73f2a068b069031d1bb2df0ec3 Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Thu, 20 Feb 2020 16:02:32 -0800
Subject: [PATCH] Forwardprop: fix nested forwardprop of non-differentiable ops

A special case was expecting that we didn't have any tangents, when in fact
we just want to discard them.

PiperOrigin-RevId: 296318800
Change-Id: I7198596435f294333e00a2dcfe5ac8ec31d0b28c
---
 tensorflow/python/eager/forwardprop_test.py | 8 ++++++++
 tensorflow/python/eager/function.py         | 7 ++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index fed04aec270..71473e51706 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -230,6 +230,14 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
         ))
     self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
 
+  def testNonDifferentiableOpWithInputTangent(self):
+    x = constant_op.constant(1.)
+    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
+      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
+        y = array_ops.zeros_like(x)
+      self.assertIsNone(acc1.jvp(y))
+    self.assertIsNone(acc2.jvp(y))
+
   def testJVPFunctionUsedByAccumulatorForOps(self):
     previous_fn = forwardprop._jvp_dispatch
     try:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 895a5de7765..c16060422b8 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -981,7 +981,7 @@ class _TapeGradientFunctions(object):
         self._func_graph.outputs,
         forward_function_attr)
 
-    if not self._func_graph.outputs or not input_tangents:
+    if not input_tangents:
       # There is no need to special-case forwardprop, so we can return the
       # forward+backward pair we've created without further wrapping.
       return (forward_function, self._func_graph, backward_function,
@@ -1085,6 +1085,11 @@
          "StatefulPartitionedCall": gradient_function}):
       forward_outputs = forward_function.call(context.context(),
                                               forward_inputs)
+      if isinstance(forward_outputs, ops.Operation):
+        # _wrapped_backward_function expects a list, but if the function has
+        # no outputs its call() returns an Operation. We need to undo that
+        # so we don't cause problems later.
+        forward_outputs = []
     py_backward, _ = self._wrap_backward_function(
         self._func_graph, backward_function, forward_outputs)
     # We will never request backward tape gradients for this operation
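
For readers who want to reproduce the fixed behavior outside the internal test
harness, below is a minimal sketch using the public tf.autodiff.ForwardAccumulator
API (the public counterpart of the forwardprop.ForwardAccumulator used in the
test above). It is an illustration, not part of the patch, and it assumes a
TensorFlow 2.x build that already includes this change.

# Minimal sketch (assumption: a TF 2.x build with this fix applied). Running a
# non-differentiable op such as zeros_like under two nested ForwardAccumulators
# no longer trips the broken special case; the input tangents are simply
# discarded and both accumulators report no JVP for the result.
import tensorflow as tf

x = tf.constant(1.)
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(2.)) as acc1:
  with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(2.)) as acc2:
    y = tf.zeros_like(x)  # non-differentiable op; tangents of x are discarded
print(acc1.jvp(y))  # None: no JVP flows through a non-differentiable op
print(acc2.jvp(y))  # None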