diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index 6282a6c4595..669fa084888 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -22,7 +22,6 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import core -from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,11 +34,6 @@ from tensorflow.python.ops import resource_variable_ops class Tests(test.TestCase): - def setUp(self): - # Force-load `distribution_strategy_context` to prevent GC at - # test time. See discussion in cl//219478951. - tape.distribution_strategy_context.get_distribution_strategy() - @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def testFastpathExecute_MatMulCorrectResponse(self): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 73f6682ec0b..fd55ad2af9e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -53,7 +53,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context -from tensorflow.python.eager import tape # pylint: disable=unused-import +from tensorflow.python.eager import tape from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -748,6 +748,10 @@ def assert_no_garbage_created(f): def decorator(self, **kwargs): """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" + # Force-load `distribution_strategy_context` to prevent GC at + # test time when using eager. Remove once b/117329403 is resolved. + tape.distribution_strategy_context.get_distribution_strategy() + gc.disable() previous_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL)