Creating a ResourceTracker object and resource_tracker_scope that tracks

resources created. This is needed for the Estimator case since we're getting
rid of collections

PiperOrigin-RevId: 219889472
This commit is contained in:
Rohan Jain 2018-11-02 17:37:32 -07:00 committed by TensorFlower Gardener
parent 1eb896332f
commit 5f915f4dc5
2 changed files with 109 additions and 0 deletions

View File

@ -19,6 +19,11 @@ from __future__ import print_function
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.util import tf_contextlib
# global _RESOURCE_TRACKER_STACK
_RESOURCE_TRACKER_STACK = []
class NotCheckpointable(object):
@ -72,10 +77,57 @@ class Checkpointable(base.CheckpointableBase):
return data_structures.NoDependency(value)
class ResourceTracker(object):
"""An object that tracks a list of resources."""
def __init__(self):
self._resources = []
@property
def resources(self):
return self._resources
def add_resource(self, resource):
self._resources.append(resource)
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
"""A context to manage resource trackers.
Use this in order to collect up all resources created within a block of code.
Example usage:
```python
resource_tracker = ResourceTracker()
with resource_tracker_scope(resource_tracker):
resource = TrackableResource()
assert resource_tracker.resources == [resource]
Args:
resource_tracker: The passed in ResourceTracker object
Yields:
A scope in which the resource_tracker is active.
"""
global _RESOURCE_TRACKER_STACK
old = list(_RESOURCE_TRACKER_STACK)
_RESOURCE_TRACKER_STACK.append(resource_tracker)
try:
yield
finally:
_RESOURCE_TRACKER_STACK = old
class TrackableResource(base.CheckpointableBase):
"""Base class for all resources that need to be tracked."""
def __init__(self):
global _RESOURCE_TRACKER_STACK
for resource_tracker in _RESOURCE_TRACKER_STACK:
resource_tracker.add_resource(self)
self._resource_handle = None
def create_resource(self):

View File

@ -193,5 +193,62 @@ class InterfaceTests(test.TestCase):
self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
self.evaluate(a.tensors))
class _DummyResource(tracking.TrackableResource):
def __init__(self, handle_name):
self._handle_name = handle_name
super(_DummyResource, self).__init__()
def create_resource(self):
return self._handle_name
class ResourceTrackerTest(test.TestCase):
def testBasic(self):
resource_tracker = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker):
dummy_resource1 = _DummyResource("test1")
dummy_resource2 = _DummyResource("test2")
self.assertEqual(2, len(resource_tracker.resources))
self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
self.assertEqual("test2", resource_tracker.resources[1].resource_handle)
def testTwoScopes(self):
resource_tracker1 = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker1):
dummy_resource1 = _DummyResource("test1")
resource_tracker2 = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker2):
dummy_resource2 = _DummyResource("test2")
self.assertEqual(1, len(resource_tracker1.resources))
self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
self.assertEqual(1, len(resource_tracker1.resources))
self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
def testNestedScopesScopes(self):
resource_tracker = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker):
resource_tracker1 = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker1):
dummy_resource1 = _DummyResource("test1")
resource_tracker2 = tracking.ResourceTracker()
with tracking.resource_tracker_scope(resource_tracker2):
dummy_resource2 = _DummyResource("test2")
self.assertEqual(1, len(resource_tracker1.resources))
self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
self.assertEqual(1, len(resource_tracker1.resources))
self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
self.assertEqual(2, len(resource_tracker.resources))
self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
self.assertEqual("test2", resource_tracker.resources[1].resource_handle)
if __name__ == "__main__":
test.main()