From d8ba8138a221dce01fef830f6cb58b1a008bbbc5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 12 Nov 2018 03:23:07 -0800 Subject: [PATCH] Moved the workaround for circular import into assert_no_garbage_created Currently, there is a circular dependency between eager/tape.py framework/ops.py training/distribution_strategy_context.py which is resolved by adding lazy loaders in each of these modules. However, lazy loading causes spurious allocations during eager test runs and consecutively fails tests decorated with @assert_no_garbage_collected. The workaround could be removed once the lazy loading is removed. PiperOrigin-RevId: 221065109 --- tensorflow/python/eager/pywrap_tfe_test.py | 6 ------ tensorflow/python/framework/test_util.py | 6 +++++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index 6282a6c4595..669fa084888 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -22,7 +22,6 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import core -from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,11 +34,6 @@ from tensorflow.python.ops import resource_variable_ops class Tests(test.TestCase): - def setUp(self): - # Force-load `distribution_strategy_context` to prevent GC at - # test time. See discussion in cl//219478951. - tape.distribution_strategy_context.get_distribution_strategy() - @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def testFastpathExecute_MatMulCorrectResponse(self): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 73f6682ec0b..fd55ad2af9e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -53,7 +53,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context -from tensorflow.python.eager import tape # pylint: disable=unused-import +from tensorflow.python.eager import tape from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -748,6 +748,10 @@ def assert_no_garbage_created(f): def decorator(self, **kwargs): """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" + # Force-load `distribution_strategy_context` to prevent GC at + # test time when using eager. Remove once b/117329403 is resolved. + tape.distribution_strategy_context.get_distribution_strategy() + gc.disable() previous_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL)