Internal cleanup: Extract the function caching mechanism into a standalone utility module. A subsequent pass will enable this new implementation.

PiperOrigin-RevId: 304823326
Change-Id: I605467b489f189255132ab5441f1f46bfcb9e42c
This commit is contained in:
Dan Moldovan 2020-04-04 15:32:53 -07:00 committed by TensorFlower Gardener
parent 56244eb940
commit 360518666e
3 changed files with 192 additions and 0 deletions
tensorflow/python/autograph/pyct

View File

@ -24,6 +24,7 @@ py_library(
"__init__.py",
"anno.py",
"ast_util.py",
"cache.py",
"cfg.py",
"error_utils.py",
"errors.py",
@ -76,6 +77,21 @@ py_test(
],
)
py_test(
name = "cache_test",
srcs = ["cache_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss_py2",
],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
"@gast_archive//:gast",
],
)
py_test(
name = "cfg_test",
srcs = ["cfg_test.py"],

View File

@ -0,0 +1,97 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Caching utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import weakref
# TODO(mdan): Add a garbage collection hook for cleaning up modules.
class _TransformedFnCache(object):
"""Generic hierarchical cache for transformed functions.
The keys are soft references (i.e. they are discarded when the key is
destroyed) created from the source function by `_get_key`. The subkeys are
strong references and can be any value. Typically they identify different
kinds of transformation.
"""
__slots__ = ('_cache',)
def __init__(self):
self._cache = weakref.WeakKeyDictionary()
def _get_key(self, entity):
raise NotImplementedError('subclasses must override')
def has(self, entity, subkey):
key = self._get_key(entity)
parent = self._cache.get(key, None)
if parent is None:
return False
return subkey in parent
def __getitem__(self, entity):
key = self._get_key(entity)
parent = self._cache.get(key, None)
if parent is None:
# The bucket is initialized to support this usage:
# cache[key][subkey] = value
self._cache[key] = parent = {}
return parent
def __len__(self):
return len(self._cache)
class CodeObjectCache(_TransformedFnCache):
"""A function cache based on code objects.
Code objects are good proxies for the source code of a function.
This cache efficiently handles functions that share code objects, such as
functions defined in a loop, bound methods, etc.
The cache falls back to the function object, if it doesn't have a code object.
"""
def _get_key(self, entity):
if hasattr(entity, '__code__'):
return entity.__code__
else:
return entity
class UnboundInstanceCache(_TransformedFnCache):
"""A function cache based on unbound function objects.
Using the function for the cache key allows efficient handling of object
methods.
Unlike the _CodeObjectCache, this discriminates between different functions
even if they have the same code. This is needed for decorators that may
masquerade as another function.
"""
def _get_key(self, entity):
if inspect.ismethod(entity):
return entity.__func__
return entity

View File

@ -0,0 +1,79 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cache module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.platform import test
class CacheTest(test.TestCase):
def test_code_object_cache(self):
def factory(x):
def test_fn():
return x + 1
return test_fn
c = cache.CodeObjectCache()
f1 = factory(1)
dummy = object()
c[f1][1] = dummy
self.assertTrue(c.has(f1, 1))
self.assertFalse(c.has(f1, 2))
self.assertIs(c[f1][1], dummy)
self.assertEqual(len(c), 1)
f2 = factory(2)
self.assertTrue(c.has(f2, 1))
self.assertIs(c[f2][1], dummy)
self.assertEqual(len(c), 1)
def test_unbound_instance_cache(self):
class TestClass(object):
def method(self):
pass
c = cache.UnboundInstanceCache()
o1 = TestClass()
dummy = object()
c[o1.method][1] = dummy
self.assertTrue(c.has(o1.method, 1))
self.assertFalse(c.has(o1.method, 2))
self.assertIs(c[o1.method][1], dummy)
self.assertEqual(len(c), 1)
o2 = TestClass()
self.assertTrue(c.has(o2.method, 1))
self.assertIs(c[o2.method][1], dummy)
self.assertEqual(len(c), 1)
if __name__ == '__main__':
test.main()