Forwardprop: opt the forwardprop utility function out of run_functions_eagerly
If the function did execute eagerly, it would be very inefficient. Fixes #39075. PiperOrigin-RevId: 309992909 Change-Id: I3e31778390beb7a2808a33aa5fe18a5e9bd41bab
This commit is contained in:
parent
7dded88b95
commit
3e6697b916
@@ -23,9 +23,9 @@ import threading
|
|||||||
from tensorflow.python import pywrap_tfe
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import backprop_util
|
from tensorflow.python.eager import backprop_util
|
||||||
from tensorflow.python.eager import def_function
|
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import forwardprop_util
|
from tensorflow.python.eager import forwardprop_util
|
||||||
|
from tensorflow.python.eager import function
|
||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
@@ -145,9 +145,15 @@ def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
|
|||||||
# implementations, or a more satisfying story about how we re-specialize
|
# implementations, or a more satisfying story about how we re-specialize
|
||||||
# gradients which were traced with relaxed shapes (e.g. use conds instead of
|
# gradients which were traced with relaxed shapes (e.g. use conds instead of
|
||||||
# trace-time Python logic).
|
# trace-time Python logic).
|
||||||
_jvp_relaxed_shapes = def_function.function(
|
#
|
||||||
|
# Using function.defun rather than def_function.function avoids
|
||||||
|
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
|
||||||
|
# eagerly (infinite recursion), and even if it did it would use extra memory and
|
||||||
|
# run unnecessary computation. The function does not create variables, so the
|
||||||
|
# two symbols are otherwise equivalent.
|
||||||
|
_jvp_relaxed_shapes = function.defun(
|
||||||
_jvp_helper, experimental_relax_shapes=True)
|
_jvp_helper, experimental_relax_shapes=True)
|
||||||
_jvp_exact_shapes = def_function.function(
|
_jvp_exact_shapes = function.defun(
|
||||||
_jvp_helper, experimental_relax_shapes=False)
|
_jvp_helper, experimental_relax_shapes=False)
|
||||||
|
|
||||||
# The maximum number of exact-shape traces to perform for a single op before
|
# The maximum number of exact-shape traces to perform for a single op before
|
||||||
|
@@ -235,6 +235,17 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIsNone(acc1.jvp(y))
|
self.assertIsNone(acc1.jvp(y))
|
||||||
self.assertIsNone(acc2.jvp(y))
|
self.assertIsNone(acc2.jvp(y))
|
||||||
|
|
||||||
|
def testRunFunctionsEagerly(self):
|
||||||
|
try:
|
||||||
|
original_setting = def_function.functions_run_eagerly()
|
||||||
|
def_function.run_functions_eagerly(True)
|
||||||
|
x = constant_op.constant(1.)
|
||||||
|
with forwardprop.ForwardAccumulator(x, 2.) as acc:
|
||||||
|
y = x * 3.
|
||||||
|
self.assertAllClose(6., acc.jvp(y))
|
||||||
|
finally:
|
||||||
|
def_function.run_functions_eagerly(original_setting)
|
||||||
|
|
||||||
def testJVPFunctionUsedByAccumulatorForOps(self):
|
def testJVPFunctionUsedByAccumulatorForOps(self):
|
||||||
previous_fn = forwardprop._jvp_dispatch
|
previous_fn = forwardprop._jvp_dispatch
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user