Internal cleanup: Extract the function caching mechanism into a standalone utility module. A subsequent pass will enable this new implementation.
PiperOrigin-RevId: 304823326 Change-Id: I605467b489f189255132ab5441f1f46bfcb9e42c
This commit is contained in:
parent
56244eb940
commit
360518666e
tensorflow/python/autograph/pyct
@ -24,6 +24,7 @@ py_library(
|
||||
"__init__.py",
|
||||
"anno.py",
|
||||
"ast_util.py",
|
||||
"cache.py",
|
||||
"cfg.py",
|
||||
"error_utils.py",
|
||||
"errors.py",
|
||||
@ -76,6 +77,21 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cache_test",
|
||||
srcs = ["cache_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
],
|
||||
deps = [
|
||||
":pyct",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cfg_test",
|
||||
srcs = ["cfg_test.py"],
|
||||
|
97
tensorflow/python/autograph/pyct/cache.py
Normal file
97
tensorflow/python/autograph/pyct/cache.py
Normal file
@ -0,0 +1,97 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Caching utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
import weakref
|
||||
|
||||
|
||||
# TODO(mdan): Add a garbage collection hook for cleaning up modules.
|
||||
class _TransformedFnCache(object):
|
||||
"""Generic hierarchical cache for transformed functions.
|
||||
|
||||
The keys are soft references (i.e. they are discarded when the key is
|
||||
destroyed) created from the source function by `_get_key`. The subkeys are
|
||||
strong references and can be any value. Typically they identify different
|
||||
kinds of transformation.
|
||||
"""
|
||||
|
||||
__slots__ = ('_cache',)
|
||||
|
||||
def __init__(self):
|
||||
self._cache = weakref.WeakKeyDictionary()
|
||||
|
||||
def _get_key(self, entity):
|
||||
raise NotImplementedError('subclasses must override')
|
||||
|
||||
def has(self, entity, subkey):
|
||||
key = self._get_key(entity)
|
||||
parent = self._cache.get(key, None)
|
||||
if parent is None:
|
||||
return False
|
||||
return subkey in parent
|
||||
|
||||
def __getitem__(self, entity):
|
||||
key = self._get_key(entity)
|
||||
parent = self._cache.get(key, None)
|
||||
if parent is None:
|
||||
# The bucket is initialized to support this usage:
|
||||
# cache[key][subkey] = value
|
||||
self._cache[key] = parent = {}
|
||||
return parent
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
class CodeObjectCache(_TransformedFnCache):
|
||||
"""A function cache based on code objects.
|
||||
|
||||
Code objects are good proxies for the source code of a function.
|
||||
|
||||
This cache efficiently handles functions that share code objects, such as
|
||||
functions defined in a loop, bound methods, etc.
|
||||
|
||||
The cache falls back to the function object, if it doesn't have a code object.
|
||||
"""
|
||||
|
||||
def _get_key(self, entity):
|
||||
if hasattr(entity, '__code__'):
|
||||
return entity.__code__
|
||||
else:
|
||||
return entity
|
||||
|
||||
|
||||
class UnboundInstanceCache(_TransformedFnCache):
|
||||
"""A function cache based on unbound function objects.
|
||||
|
||||
Using the function for the cache key allows efficient handling of object
|
||||
methods.
|
||||
|
||||
Unlike the _CodeObjectCache, this discriminates between different functions
|
||||
even if they have the same code. This is needed for decorators that may
|
||||
masquerade as another function.
|
||||
"""
|
||||
|
||||
def _get_key(self, entity):
|
||||
if inspect.ismethod(entity):
|
||||
return entity.__func__
|
||||
return entity
|
||||
|
||||
|
79
tensorflow/python/autograph/pyct/cache_test.py
Normal file
79
tensorflow/python/autograph/pyct/cache_test.py
Normal file
@ -0,0 +1,79 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for cache module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.pyct import cache
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CacheTest(test.TestCase):
|
||||
|
||||
def test_code_object_cache(self):
|
||||
|
||||
def factory(x):
|
||||
def test_fn():
|
||||
return x + 1
|
||||
return test_fn
|
||||
|
||||
c = cache.CodeObjectCache()
|
||||
|
||||
f1 = factory(1)
|
||||
dummy = object()
|
||||
|
||||
c[f1][1] = dummy
|
||||
|
||||
self.assertTrue(c.has(f1, 1))
|
||||
self.assertFalse(c.has(f1, 2))
|
||||
self.assertIs(c[f1][1], dummy)
|
||||
self.assertEqual(len(c), 1)
|
||||
|
||||
f2 = factory(2)
|
||||
|
||||
self.assertTrue(c.has(f2, 1))
|
||||
self.assertIs(c[f2][1], dummy)
|
||||
self.assertEqual(len(c), 1)
|
||||
|
||||
def test_unbound_instance_cache(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def method(self):
|
||||
pass
|
||||
|
||||
c = cache.UnboundInstanceCache()
|
||||
|
||||
o1 = TestClass()
|
||||
dummy = object()
|
||||
|
||||
c[o1.method][1] = dummy
|
||||
|
||||
self.assertTrue(c.has(o1.method, 1))
|
||||
self.assertFalse(c.has(o1.method, 2))
|
||||
self.assertIs(c[o1.method][1], dummy)
|
||||
self.assertEqual(len(c), 1)
|
||||
|
||||
o2 = TestClass()
|
||||
|
||||
self.assertTrue(c.has(o2.method, 1))
|
||||
self.assertIs(c[o2.method][1], dummy)
|
||||
self.assertEqual(len(c), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user