Forwardprop: opt the forwardprop utility function out of run_functions_eagerly

If the function did execute eagerly it would be very inefficient.

Fixes #39075.

PiperOrigin-RevId: 309992909
Change-Id: I3e31778390beb7a2808a33aa5fe18a5e9bd41bab
This commit is contained in:
Allen Lavoie 2020-05-05 12:02:04 -07:00 committed by TensorFlower Gardener
parent 7dded88b95
commit 3e6697b916
2 changed files with 20 additions and 3 deletions

View File

@ -23,9 +23,9 @@ import threading
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
@ -145,9 +145,15 @@ def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
# implementations, or a more satisfying story about how we re-specialize
# gradients which were traced with relaxed shapes (e.g. use conds instead of
# trace-time Python logic).
_jvp_relaxed_shapes = def_function.function(
#
# Using function.defun rather than def_function.function avoids
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
# eagerly (infinite recursion), and even if it did it would use extra memory and
# run unnecessary computation. The function does not create variables, so the
# two symbols are otherwise equivalent.
_jvp_relaxed_shapes = function.defun(
_jvp_helper, experimental_relax_shapes=True)
_jvp_exact_shapes = def_function.function(
_jvp_exact_shapes = function.defun(
_jvp_helper, experimental_relax_shapes=False)
# The maximum number of exact-shape traces to perform for a single op before

View File

@ -235,6 +235,17 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
self.assertIsNone(acc1.jvp(y))
self.assertIsNone(acc2.jvp(y))
def testRunFunctionsEagerly(self):
try:
original_setting = def_function.functions_run_eagerly()
def_function.run_functions_eagerly(True)
x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc:
y = x * 3.
self.assertAllClose(6., acc.jvp(y))
finally:
def_function.run_functions_eagerly(original_setting)
def testJVPFunctionUsedByAccumulatorForOps(self):
previous_fn = forwardprop._jvp_dispatch
try: