Make tf.test.TestCase respect superclass setUp and tearDown.

PiperOrigin-RevId: 336887640
Change-Id: Ibfb1c49e4ec2ac4ac313a81b01d30a81b9fd14c4
This commit is contained in:
A. Unique TensorFlower 2020-10-13 09:06:03 -07:00 committed by TensorFlower Gardener
parent 064e478a44
commit cf210c870e
2 changed files with 31 additions and 12 deletions
tensorflow/python/framework

View File

@ -120,8 +120,11 @@ except ImportError:
pass
def _get_object_count_by_type():
return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])
def _get_object_count_by_type(exclude=()):
return (
collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
collections.Counter([type(obj).__name__ for obj in exclude]))
@tf_export("test.gpu_device_name")
def gpu_device_name():
@ -657,12 +660,20 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
# versions of python2.7.x.
for _ in range(warmup_iters):
f(self, *args, **kwargs)
# Since we aren't in the normal test lifecylce, we need to manually run
# cleanups to clear out their object references.
self.doCleanups()
# Some objects are newly created by _get_object_count_by_type(). So
# create and save as a dummy variable to include it as a baseline.
obj_count_by_type = _get_object_count_by_type()
gc.collect()
obj_count_by_type = _get_object_count_by_type()
# unittest.doCleanups adds to self._outcome with each unwound call.
# These objects are retained across gc collections so we exclude them
# from the object count calculation.
obj_count_by_type = _get_object_count_by_type(
exclude=gc.get_referents(self._outcome.errors,
self._outcome.skipped))
if ops.has_default_graph():
collection_sizes_before = {
@ -671,6 +682,9 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
}
for _ in range(3):
f(self, *args, **kwargs)
# Since we aren't in the normal test lifecylce, we need to manually run
# cleanups to clear out their object references.
self.doCleanups()
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
@ -692,7 +706,11 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
gc.collect()
# There should be no new Python objects hanging around.
obj_count_by_type = _get_object_count_by_type() - obj_count_by_type
obj_count_by_type = (
_get_object_count_by_type(
exclude=gc.get_referents(self._outcome.errors,
self._outcome.skipped)) -
obj_count_by_type)
# In some cases (specifically on MacOS), new_count is somehow
# smaller than previous_count.
# Using plain assert because not all classes using this decorator
@ -2013,6 +2031,7 @@ class TensorFlowTestCase(googletest.TestCase):
self._test_start_time = None
def setUp(self):
super(TensorFlowTestCase, self).setUp()
self._ClearCachedSession()
random.seed(random_seed.DEFAULT_GRAPH_SEED)
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
@ -2047,6 +2066,7 @@ class TensorFlowTestCase(googletest.TestCase):
thread.check_termination()
self._ClearCachedSession()
super(TensorFlowTestCase, self).tearDown()
def _ClearCachedSession(self):
if self._cached_session is not None:

View File

@ -962,12 +962,13 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
def test_no_new_objects_decorator(self):
class LeakedObjectTest(object):
class LeakedObjectTest(unittest.TestCase):
def __init__(inner_self): # pylint: disable=no-self-argument
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
inner_self.accumulation = []
def __init__(self, *args, **kwargs):
super(LeakedObjectTest, self).__init__(*args, **kwargs)
self.accumulation = []
@unittest.expectedFailure
@test_util.assert_no_new_pyobjects_executing_eagerly
def test_has_leak(self):
self.accumulation.append([1.])
@ -976,10 +977,8 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
def test_has_no_leak(self):
self.not_accumulating = [1.]
with self.assertRaises(AssertionError):
LeakedObjectTest().test_has_leak()
LeakedObjectTest().test_has_no_leak()
self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,