Moved the workaround for circular import into assert_no_garbage_created
Currently, there is a circular dependency between eager/tape.py framework/ops.py training/distribution_strategy_context.py which is resolved by adding lazy loaders in each of these modules. However, lazy loading causes spurious allocations during eager test runs and consecutively fails tests decorated with @assert_no_garbage_collected. The workaround could be removed once the lazy loading is removed. PiperOrigin-RevId: 221065109
This commit is contained in:
parent
25409786f1
commit
d8ba8138a2
@ -22,7 +22,6 @@ from tensorflow.python import pywrap_tensorflow
|
|||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import tape
|
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -35,11 +34,6 @@ from tensorflow.python.ops import resource_variable_ops
|
|||||||
|
|
||||||
class Tests(test.TestCase):
|
class Tests(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# Force-load `distribution_strategy_context` to prevent GC at
|
|
||||||
# test time. See discussion in cl//219478951.
|
|
||||||
tape.distribution_strategy_context.get_distribution_strategy()
|
|
||||||
|
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@test_util.assert_no_garbage_created
|
@test_util.assert_no_garbage_created
|
||||||
def testFastpathExecute_MatMulCorrectResponse(self):
|
def testFastpathExecute_MatMulCorrectResponse(self):
|
||||||
|
@ -53,7 +53,7 @@ from tensorflow.python import pywrap_tensorflow
|
|||||||
from tensorflow.python.client import device_lib
|
from tensorflow.python.client import device_lib
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape # pylint: disable=unused-import
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import device as pydev
|
from tensorflow.python.framework import device as pydev
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -748,6 +748,10 @@ def assert_no_garbage_created(f):
|
|||||||
|
|
||||||
def decorator(self, **kwargs):
|
def decorator(self, **kwargs):
|
||||||
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
|
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
|
||||||
|
# Force-load `distribution_strategy_context` to prevent GC at
|
||||||
|
# test time when using eager. Remove once b/117329403 is resolved.
|
||||||
|
tape.distribution_strategy_context.get_distribution_strategy()
|
||||||
|
|
||||||
gc.disable()
|
gc.disable()
|
||||||
previous_debug_flags = gc.get_debug()
|
previous_debug_flags = gc.get_debug()
|
||||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||||
|
Loading…
Reference in New Issue
Block a user