diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 7881b17f88b..735d504f18f 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -24,6 +24,7 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cache.py", "cfg.py", "error_utils.py", "errors.py", @@ -76,6 +77,21 @@ py_test( ], ) +py_test( + name = "cache_test", + srcs = ["cache_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_oss_py2", + ], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "cfg_test", srcs = ["cfg_test.py"], diff --git a/tensorflow/python/autograph/pyct/cache.py b/tensorflow/python/autograph/pyct/cache.py new file mode 100644 index 00000000000..d9af6e6156a --- /dev/null +++ b/tensorflow/python/autograph/pyct/cache.py @@ -0,0 +1,97 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Caching utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import weakref + + +# TODO(mdan): Add a garbage collection hook for cleaning up modules. +class _TransformedFnCache(object): + """Generic hierarchical cache for transformed functions. + + The keys are soft references (i.e. they are discarded when the key is + destroyed) created from the source function by `_get_key`. The subkeys are + strong references and can be any value. Typically they identify different + kinds of transformation. + """ + + __slots__ = ('_cache',) + + def __init__(self): + self._cache = weakref.WeakKeyDictionary() + + def _get_key(self, entity): + raise NotImplementedError('subclasses must override') + + def has(self, entity, subkey): + key = self._get_key(entity) + parent = self._cache.get(key, None) + if parent is None: + return False + return subkey in parent + + def __getitem__(self, entity): + key = self._get_key(entity) + parent = self._cache.get(key, None) + if parent is None: + # The bucket is initialized to support this usage: + # cache[key][subkey] = value + self._cache[key] = parent = {} + return parent + + def __len__(self): + return len(self._cache) + + +class CodeObjectCache(_TransformedFnCache): + """A function cache based on code objects. + + Code objects are good proxies for the source code of a function. + + This cache efficiently handles functions that share code objects, such as + functions defined in a loop, bound methods, etc. + + The cache falls back to the function object, if it doesn't have a code object. + """ + + def _get_key(self, entity): + if hasattr(entity, '__code__'): + return entity.__code__ + else: + return entity + + +class UnboundInstanceCache(_TransformedFnCache): + """A function cache based on unbound function objects. + + Using the function for the cache key allows efficient handling of object + methods. + + Unlike the _CodeObjectCache, this discriminates between different functions + even if they have the same code. This is needed for decorators that may + masquerade as another function. + """ + + def _get_key(self, entity): + if inspect.ismethod(entity): + return entity.__func__ + return entity + + diff --git a/tensorflow/python/autograph/pyct/cache_test.py b/tensorflow/python/autograph/pyct/cache_test.py new file mode 100644 index 00000000000..6c40954be56 --- /dev/null +++ b/tensorflow/python/autograph/pyct/cache_test.py @@ -0,0 +1,79 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cache module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.pyct import cache +from tensorflow.python.platform import test + + +class CacheTest(test.TestCase): + + def test_code_object_cache(self): + + def factory(x): + def test_fn(): + return x + 1 + return test_fn + + c = cache.CodeObjectCache() + + f1 = factory(1) + dummy = object() + + c[f1][1] = dummy + + self.assertTrue(c.has(f1, 1)) + self.assertFalse(c.has(f1, 2)) + self.assertIs(c[f1][1], dummy) + self.assertEqual(len(c), 1) + + f2 = factory(2) + + self.assertTrue(c.has(f2, 1)) + self.assertIs(c[f2][1], dummy) + self.assertEqual(len(c), 1) + + def test_unbound_instance_cache(self): + + class TestClass(object): + + def method(self): + pass + + c = cache.UnboundInstanceCache() + + o1 = TestClass() + dummy = object() + + c[o1.method][1] = dummy + + self.assertTrue(c.has(o1.method, 1)) + self.assertFalse(c.has(o1.method, 2)) + self.assertIs(c[o1.method][1], dummy) + self.assertEqual(len(c), 1) + + o2 = TestClass() + + self.assertTrue(c.has(o2.method, 1)) + self.assertIs(c[o2.method][1], dummy) + self.assertEqual(len(c), 1) + + +if __name__ == '__main__': + test.main()