Creating a ResourceTracker object and resource_tracker_scope that tracks
resources created. This is needed for the Estimator case since we're getting rid of collections PiperOrigin-RevId: 219889472
This commit is contained in:
parent
1eb896332f
commit
5f915f4dc5
@ -19,6 +19,11 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.training.checkpointable import base
|
||||
from tensorflow.python.training.checkpointable import data_structures
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
# global _RESOURCE_TRACKER_STACK
|
||||
_RESOURCE_TRACKER_STACK = []
|
||||
|
||||
|
||||
class NotCheckpointable(object):
|
||||
@ -72,10 +77,57 @@ class Checkpointable(base.CheckpointableBase):
|
||||
return data_structures.NoDependency(value)
|
||||
|
||||
|
||||
class ResourceTracker(object):
|
||||
"""An object that tracks a list of resources."""
|
||||
|
||||
def __init__(self):
|
||||
self._resources = []
|
||||
|
||||
@property
|
||||
def resources(self):
|
||||
return self._resources
|
||||
|
||||
def add_resource(self, resource):
|
||||
self._resources.append(resource)
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def resource_tracker_scope(resource_tracker):
|
||||
"""A context to manage resource trackers.
|
||||
|
||||
Use this in order to collect up all resources created within a block of code.
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
resource_tracker = ResourceTracker()
|
||||
with resource_tracker_scope(resource_tracker):
|
||||
resource = TrackableResource()
|
||||
|
||||
assert resource_tracker.resources == [resource]
|
||||
|
||||
Args:
|
||||
resource_tracker: The passed in ResourceTracker object
|
||||
|
||||
Yields:
|
||||
A scope in which the resource_tracker is active.
|
||||
"""
|
||||
global _RESOURCE_TRACKER_STACK
|
||||
old = list(_RESOURCE_TRACKER_STACK)
|
||||
_RESOURCE_TRACKER_STACK.append(resource_tracker)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_RESOURCE_TRACKER_STACK = old
|
||||
|
||||
|
||||
class TrackableResource(base.CheckpointableBase):
|
||||
"""Base class for all resources that need to be tracked."""
|
||||
|
||||
def __init__(self):
|
||||
global _RESOURCE_TRACKER_STACK
|
||||
for resource_tracker in _RESOURCE_TRACKER_STACK:
|
||||
resource_tracker.add_resource(self)
|
||||
|
||||
self._resource_handle = None
|
||||
|
||||
def create_resource(self):
|
||||
|
||||
@ -193,5 +193,62 @@ class InterfaceTests(test.TestCase):
|
||||
self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
|
||||
self.evaluate(a.tensors))
|
||||
|
||||
|
||||
class _DummyResource(tracking.TrackableResource):
|
||||
|
||||
def __init__(self, handle_name):
|
||||
self._handle_name = handle_name
|
||||
super(_DummyResource, self).__init__()
|
||||
|
||||
def create_resource(self):
|
||||
return self._handle_name
|
||||
|
||||
|
||||
class ResourceTrackerTest(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
resource_tracker = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker):
|
||||
dummy_resource1 = _DummyResource("test1")
|
||||
dummy_resource2 = _DummyResource("test2")
|
||||
|
||||
self.assertEqual(2, len(resource_tracker.resources))
|
||||
self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
|
||||
self.assertEqual("test2", resource_tracker.resources[1].resource_handle)
|
||||
|
||||
def testTwoScopes(self):
|
||||
resource_tracker1 = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker1):
|
||||
dummy_resource1 = _DummyResource("test1")
|
||||
|
||||
resource_tracker2 = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker2):
|
||||
dummy_resource2 = _DummyResource("test2")
|
||||
|
||||
self.assertEqual(1, len(resource_tracker1.resources))
|
||||
self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
|
||||
self.assertEqual(1, len(resource_tracker1.resources))
|
||||
self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
|
||||
|
||||
def testNestedScopesScopes(self):
|
||||
resource_tracker = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker):
|
||||
resource_tracker1 = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker1):
|
||||
dummy_resource1 = _DummyResource("test1")
|
||||
|
||||
resource_tracker2 = tracking.ResourceTracker()
|
||||
with tracking.resource_tracker_scope(resource_tracker2):
|
||||
dummy_resource2 = _DummyResource("test2")
|
||||
|
||||
self.assertEqual(1, len(resource_tracker1.resources))
|
||||
self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
|
||||
self.assertEqual(1, len(resource_tracker1.resources))
|
||||
self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
|
||||
self.assertEqual(2, len(resource_tracker.resources))
|
||||
self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
|
||||
self.assertEqual("test2", resource_tracker.resources[1].resource_handle)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user