Forwardprop: opt the forwardprop utility function out of run_functions_eagerly
If the function did execute eagerly it would be very inefficient. Fixes #39075. PiperOrigin-RevId: 309992909 Change-Id: I3e31778390beb7a2808a33aa5fe18a5e9bd41bab
This commit is contained in:
parent
7dded88b95
commit
3e6697b916
@ -23,9 +23,9 @@ import threading
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import forwardprop_util
|
||||
from tensorflow.python.eager import function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
@ -145,9 +145,15 @@ def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
|
||||
# implementations, or a more satisfying story about how we re-specialize
|
||||
# gradients which were traced with relaxed shapes (e.g. use conds instead of
|
||||
# trace-time Python logic).
|
||||
_jvp_relaxed_shapes = def_function.function(
|
||||
#
|
||||
# Using function.defun rather than def_function.function avoids
|
||||
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
|
||||
# eagerly (infinite recursion), and even if it did it would use extra memory and
|
||||
# run unnecessary computation. The function does not create variables, so the
|
||||
# two symbols are otherwise equivalent.
|
||||
_jvp_relaxed_shapes = function.defun(
|
||||
_jvp_helper, experimental_relax_shapes=True)
|
||||
_jvp_exact_shapes = def_function.function(
|
||||
_jvp_exact_shapes = function.defun(
|
||||
_jvp_helper, experimental_relax_shapes=False)
|
||||
|
||||
# The maximum number of exact-shape traces to perform for a single op before
|
||||
|
@ -235,6 +235,17 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsNone(acc1.jvp(y))
|
||||
self.assertIsNone(acc2.jvp(y))
|
||||
|
||||
def testRunFunctionsEagerly(self):
|
||||
try:
|
||||
original_setting = def_function.functions_run_eagerly()
|
||||
def_function.run_functions_eagerly(True)
|
||||
x = constant_op.constant(1.)
|
||||
with forwardprop.ForwardAccumulator(x, 2.) as acc:
|
||||
y = x * 3.
|
||||
self.assertAllClose(6., acc.jvp(y))
|
||||
finally:
|
||||
def_function.run_functions_eagerly(original_setting)
|
||||
|
||||
def testJVPFunctionUsedByAccumulatorForOps(self):
|
||||
previous_fn = forwardprop._jvp_dispatch
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user