Make tf.test.TestCase respect superclass setUp and tearDown.
PiperOrigin-RevId: 336887640
Change-Id: Ibfb1c49e4ec2ac4ac313a81b01d30a81b9fd14c4
parent 064e478a44
commit cf210c870e

tensorflow/python/framework
@@ -120,8 +120,11 @@ except ImportError:
   pass


-def _get_object_count_by_type():
-  return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])
+def _get_object_count_by_type(exclude=()):
+  return (
+      collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
+      collections.Counter([type(obj).__name__ for obj in exclude]))


 @tf_export("test.gpu_device_name")
 def gpu_device_name():
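The new exclude parameter works through collections.Counter subtraction, which keeps only positive counts, so any type whose count drops to zero disappears from the result. A minimal standalone sketch (the example values are made up):

import collections

all_objects = collections.Counter({"list": 120, "tuple": 300, "dict": 80})
excluded = collections.Counter({"tuple": 5, "dict": 80})

# Counter subtraction keeps only positive counts: "dict" vanishes entirely
# and "tuple" is reduced by the excluded amount.
print(all_objects - excluded)  # Counter({'tuple': 295, 'list': 120})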
@@ -657,12 +660,20 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
         # versions of python2.7.x.
         for _ in range(warmup_iters):
           f(self, *args, **kwargs)
+        # Since we aren't in the normal test lifecycle, we need to manually run
+        # cleanups to clear out their object references.
+        self.doCleanups()

         # Some objects are newly created by _get_object_count_by_type().  So
         # create and save as a dummy variable to include it as a baseline.
         obj_count_by_type = _get_object_count_by_type()
         gc.collect()
-        obj_count_by_type = _get_object_count_by_type()
+        # unittest.doCleanups adds to self._outcome with each unwound call.
+        # These objects are retained across gc collections so we exclude them
+        # from the object count calculation.
+        obj_count_by_type = _get_object_count_by_type(
+            exclude=gc.get_referents(self._outcome.errors,
+                                     self._outcome.skipped))

         if ops.has_default_graph():
           collection_sizes_before = {
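A simplified, standalone sketch of the exclusion added above. _FakeOutcome is a stand-in for unittest's private _Outcome object, whose errors and skipped lists grow as doCleanups() unwinds cleanups:

import collections
import gc


def _get_object_count_by_type(exclude=()):
  return (
      collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
      collections.Counter([type(obj).__name__ for obj in exclude]))


class _FakeOutcome(object):  # stand-in for unittest.case._Outcome
  def __init__(self):
    self.errors = []
    self.skipped = []


outcome = _FakeOutcome()
outcome.errors.append(("some test", None))  # the kind of entry doCleanups leaves behind

# gc.get_referents returns the objects directly referenced by its arguments,
# i.e. the bookkeeping tuples inside errors/skipped, which are then subtracted
# from the global count by type name.
baseline = _get_object_count_by_type(
    exclude=gc.get_referents(outcome.errors, outcome.skipped))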
@@ -671,6 +682,9 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
           }
         for _ in range(3):
           f(self, *args, **kwargs)
+        # Since we aren't in the normal test lifecycle, we need to manually run
+        # cleanups to clear out their object references.
+        self.doCleanups()
         # Note that gc.get_objects misses anything that isn't subject to garbage
         # collection (C types).  Collections are a common source of leaks, so we
         # test for collection sizes explicitly.
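Why the manual doCleanups() calls matter: callbacks registered with addCleanup keep their arguments alive until they are actually run, so skipping this step would make every cleaned-up object look like a leak. A small plain-unittest illustration, not part of the commit:

import gc
import unittest


class _CleanupDemo(unittest.TestCase):

  def runTest(self):
    payload = [object() for _ in range(100)]
    self.addCleanup(payload.clear)  # the bound method keeps `payload` alive
    del payload
    gc.collect()       # the 100 objects are still reachable via self._cleanups
    self.doCleanups()  # runs payload.clear() and drops the reference
    gc.collect()       # now the objects can actually be collected


_CleanupDemo("runTest").run()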
@@ -692,7 +706,11 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
         gc.collect()

         # There should be no new Python objects hanging around.
-        obj_count_by_type = _get_object_count_by_type() - obj_count_by_type
+        obj_count_by_type = (
+            _get_object_count_by_type(
+                exclude=gc.get_referents(self._outcome.errors,
+                                         self._outcome.skipped)) -
+            obj_count_by_type)
         # In some cases (specifically on MacOS), new_count is somehow
         # smaller than previous_count.
         # Using plain assert because not all classes using this decorator
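Typical use of the decorator, mirroring the GarbageCollectionTest changes further down (the class and method names here are illustrative, not from the commit):

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class NoLeakExample(test_util.TensorFlowTestCase):

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def test_does_not_accumulate(self):
    # Rebinding an attribute each iteration does not grow any collection,
    # so the object counts before and after the repeated runs match.
    self.not_accumulating = [1.]


if __name__ == "__main__":
  googletest.main()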
@@ -2013,6 +2031,7 @@ class TensorFlowTestCase(googletest.TestCase):
     self._test_start_time = None

   def setUp(self):
+    super(TensorFlowTestCase, self).setUp()
     self._ClearCachedSession()
     random.seed(random_seed.DEFAULT_GRAPH_SEED)
     np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
@@ -2047,6 +2066,7 @@ class TensorFlowTestCase(googletest.TestCase):
       thread.check_termination()

     self._ClearCachedSession()
+    super(TensorFlowTestCase, self).tearDown()

   def _ClearCachedSession(self):
     if self._cached_session is not None:
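The effect of the two super() calls, shown with plain unittest classes (all class names below are illustrative): once a base class forwards to super(), mixins that sit later in the MRO get their setUp and tearDown run as well.

import unittest


class Mixin(unittest.TestCase):

  def setUp(self):
    super(Mixin, self).setUp()
    self.resource = "opened"


class NonCooperativeBase(unittest.TestCase):

  def setUp(self):  # does NOT call super().setUp(): the chain stops here
    self.base_ready = True


class CooperativeBase(unittest.TestCase):

  def setUp(self):
    super(CooperativeBase, self).setUp()  # what this commit adds to TensorFlowTestCase
    self.base_ready = True


class Broken(NonCooperativeBase, Mixin):

  def test(self):
    self.assertFalse(hasattr(self, "resource"))  # Mixin.setUp never ran


class Fixed(CooperativeBase, Mixin):

  def test(self):
    self.assertEqual(self.resource, "opened")  # Mixin.setUp ran via super()


if __name__ == "__main__":
  unittest.main()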
@ -962,12 +962,13 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_no_new_objects_decorator(self):
|
||||
|
||||
class LeakedObjectTest(object):
|
||||
class LeakedObjectTest(unittest.TestCase):
|
||||
|
||||
def __init__(inner_self): # pylint: disable=no-self-argument
|
||||
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
|
||||
inner_self.accumulation = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LeakedObjectTest, self).__init__(*args, **kwargs)
|
||||
self.accumulation = []
|
||||
|
||||
@unittest.expectedFailure
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def test_has_leak(self):
|
||||
self.accumulation.append([1.])
|
||||
@ -976,10 +977,8 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
||||
def test_has_no_leak(self):
|
||||
self.not_accumulating = [1.]
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
LeakedObjectTest().test_has_leak()
|
||||
|
||||
LeakedObjectTest().test_has_no_leak()
|
||||
self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
|
||||
self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())
|
||||
|
||||
|
||||
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
|
||||
|
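The updated assertions above rely on two pieces of plain unittest behavior: constructing a TestCase with a method name and calling run() returns a TestResult, and a test marked @unittest.expectedFailure is reported as successful when it does fail. A standalone sketch:

import unittest


class _Demo(unittest.TestCase):

  @unittest.expectedFailure
  def test_expected_to_fail(self):
    self.assertEqual(1, 2)  # failing here is the expected outcome

  def test_passes(self):
    self.assertEqual(1, 1)


# Both runs count as successful: one passes, the other fails as expected.
assert _Demo("test_expected_to_fail").run().wasSuccessful()
assert _Demo("test_passes").run().wasSuccessful()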