Moved the workaround for circular import into assert_no_garbage_created

Currently, there is a circular dependency between

  eager/tape.py
  framework/ops.py
  training/distribution_strategy_context.py

which is resolved by adding lazy loaders in each of these modules.
However, lazy loading causes spurious allocations during eager test
runs and consecutively fails tests decorated with
@assert_no_garbage_collected.

The workaround could be removed once the lazy loading is removed.

PiperOrigin-RevId: 221065109
This commit is contained in:
Sergei Lebedev 2018-11-12 03:23:07 -08:00 committed by TensorFlower Gardener
parent 25409786f1
commit d8ba8138a2
2 changed files with 5 additions and 7 deletions

View File

@ -22,7 +22,6 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -35,11 +34,6 @@ from tensorflow.python.ops import resource_variable_ops
class Tests(test.TestCase):
def setUp(self):
# Force-load `distribution_strategy_context` to prevent GC at
# test time. See discussion in cl//219478951.
tape.distribution_strategy_context.get_distribution_strategy()
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
def testFastpathExecute_MatMulCorrectResponse(self):

View File

@ -53,7 +53,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import tape # pylint: disable=unused-import
from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -748,6 +748,10 @@ def assert_no_garbage_created(f):
def decorator(self, **kwargs):
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
# Force-load `distribution_strategy_context` to prevent GC at
# test time when using eager. Remove once b/117329403 is resolved.
tape.distribution_strategy_context.get_distribution_strategy()
gc.disable()
previous_debug_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)