Moved the workaround for circular import into assert_no_garbage_created
Currently, there is a circular dependency between eager/tape.py framework/ops.py training/distribution_strategy_context.py which is resolved by adding lazy loaders in each of these modules. However, lazy loading causes spurious allocations during eager test runs and consecutively fails tests decorated with @assert_no_garbage_collected. The workaround could be removed once the lazy loading is removed. PiperOrigin-RevId: 221065109
This commit is contained in:
parent
25409786f1
commit
d8ba8138a2
@ -22,7 +22,6 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -35,11 +34,6 @@ from tensorflow.python.ops import resource_variable_ops
|
||||
|
||||
class Tests(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Force-load `distribution_strategy_context` to prevent GC at
|
||||
# test time. See discussion in cl//219478951.
|
||||
tape.distribution_strategy_context.get_distribution_strategy()
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
def testFastpathExecute_MatMulCorrectResponse(self):
|
||||
|
@ -53,7 +53,7 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape # pylint: disable=unused-import
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -748,6 +748,10 @@ def assert_no_garbage_created(f):
|
||||
|
||||
def decorator(self, **kwargs):
|
||||
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
|
||||
# Force-load `distribution_strategy_context` to prevent GC at
|
||||
# test time when using eager. Remove once b/117329403 is resolved.
|
||||
tape.distribution_strategy_context.get_distribution_strategy()
|
||||
|
||||
gc.disable()
|
||||
previous_debug_flags = gc.get_debug()
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
|
Loading…
Reference in New Issue
Block a user