Make tf.test.TestCase respect superclass setUp and tearDown.
PiperOrigin-RevId: 336887640 Change-Id: Ibfb1c49e4ec2ac4ac313a81b01d30a81b9fd14c4
This commit is contained in:
parent
064e478a44
commit
cf210c870e
@ -120,8 +120,11 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_object_count_by_type():
|
def _get_object_count_by_type(exclude=()):
|
||||||
return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])
|
return (
|
||||||
|
collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
|
||||||
|
collections.Counter([type(obj).__name__ for obj in exclude]))
|
||||||
|
|
||||||
|
|
||||||
@tf_export("test.gpu_device_name")
|
@tf_export("test.gpu_device_name")
|
||||||
def gpu_device_name():
|
def gpu_device_name():
|
||||||
@ -657,12 +660,20 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
|||||||
# versions of python2.7.x.
|
# versions of python2.7.x.
|
||||||
for _ in range(warmup_iters):
|
for _ in range(warmup_iters):
|
||||||
f(self, *args, **kwargs)
|
f(self, *args, **kwargs)
|
||||||
|
# Since we aren't in the normal test lifecylce, we need to manually run
|
||||||
|
# cleanups to clear out their object references.
|
||||||
|
self.doCleanups()
|
||||||
|
|
||||||
# Some objects are newly created by _get_object_count_by_type(). So
|
# Some objects are newly created by _get_object_count_by_type(). So
|
||||||
# create and save as a dummy variable to include it as a baseline.
|
# create and save as a dummy variable to include it as a baseline.
|
||||||
obj_count_by_type = _get_object_count_by_type()
|
obj_count_by_type = _get_object_count_by_type()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
obj_count_by_type = _get_object_count_by_type()
|
# unittest.doCleanups adds to self._outcome with each unwound call.
|
||||||
|
# These objects are retained across gc collections so we exclude them
|
||||||
|
# from the object count calculation.
|
||||||
|
obj_count_by_type = _get_object_count_by_type(
|
||||||
|
exclude=gc.get_referents(self._outcome.errors,
|
||||||
|
self._outcome.skipped))
|
||||||
|
|
||||||
if ops.has_default_graph():
|
if ops.has_default_graph():
|
||||||
collection_sizes_before = {
|
collection_sizes_before = {
|
||||||
@ -671,6 +682,9 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
|||||||
}
|
}
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
f(self, *args, **kwargs)
|
f(self, *args, **kwargs)
|
||||||
|
# Since we aren't in the normal test lifecylce, we need to manually run
|
||||||
|
# cleanups to clear out their object references.
|
||||||
|
self.doCleanups()
|
||||||
# Note that gc.get_objects misses anything that isn't subject to garbage
|
# Note that gc.get_objects misses anything that isn't subject to garbage
|
||||||
# collection (C types). Collections are a common source of leaks, so we
|
# collection (C types). Collections are a common source of leaks, so we
|
||||||
# test for collection sizes explicitly.
|
# test for collection sizes explicitly.
|
||||||
@ -692,7 +706,11 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
# There should be no new Python objects hanging around.
|
# There should be no new Python objects hanging around.
|
||||||
obj_count_by_type = _get_object_count_by_type() - obj_count_by_type
|
obj_count_by_type = (
|
||||||
|
_get_object_count_by_type(
|
||||||
|
exclude=gc.get_referents(self._outcome.errors,
|
||||||
|
self._outcome.skipped)) -
|
||||||
|
obj_count_by_type)
|
||||||
# In some cases (specifically on MacOS), new_count is somehow
|
# In some cases (specifically on MacOS), new_count is somehow
|
||||||
# smaller than previous_count.
|
# smaller than previous_count.
|
||||||
# Using plain assert because not all classes using this decorator
|
# Using plain assert because not all classes using this decorator
|
||||||
@ -2013,6 +2031,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
self._test_start_time = None
|
self._test_start_time = None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super(TensorFlowTestCase, self).setUp()
|
||||||
self._ClearCachedSession()
|
self._ClearCachedSession()
|
||||||
random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||||
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||||
@ -2047,6 +2066,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
thread.check_termination()
|
thread.check_termination()
|
||||||
|
|
||||||
self._ClearCachedSession()
|
self._ClearCachedSession()
|
||||||
|
super(TensorFlowTestCase, self).tearDown()
|
||||||
|
|
||||||
def _ClearCachedSession(self):
|
def _ClearCachedSession(self):
|
||||||
if self._cached_session is not None:
|
if self._cached_session is not None:
|
||||||
|
@ -962,12 +962,13 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def test_no_new_objects_decorator(self):
|
def test_no_new_objects_decorator(self):
|
||||||
|
|
||||||
class LeakedObjectTest(object):
|
class LeakedObjectTest(unittest.TestCase):
|
||||||
|
|
||||||
def __init__(inner_self): # pylint: disable=no-self-argument
|
def __init__(self, *args, **kwargs):
|
||||||
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
|
super(LeakedObjectTest, self).__init__(*args, **kwargs)
|
||||||
inner_self.accumulation = []
|
self.accumulation = []
|
||||||
|
|
||||||
|
@unittest.expectedFailure
|
||||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
def test_has_leak(self):
|
def test_has_leak(self):
|
||||||
self.accumulation.append([1.])
|
self.accumulation.append([1.])
|
||||||
@ -976,10 +977,8 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
|||||||
def test_has_no_leak(self):
|
def test_has_no_leak(self):
|
||||||
self.not_accumulating = [1.]
|
self.not_accumulating = [1.]
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
|
||||||
LeakedObjectTest().test_has_leak()
|
self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())
|
||||||
|
|
||||||
LeakedObjectTest().test_has_no_leak()
|
|
||||||
|
|
||||||
|
|
||||||
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
|
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
|
||||||
|
Loading…
Reference in New Issue
Block a user