Move cached_per_instance to Keras utils since it's only used in Keras.

PiperOrigin-RevId: 326677839
Change-Id: I04fb71d17241b65fc1d5ae8f69e4d40770357bf7
Scott Zhu 2020-08-14 10:09:39 -07:00 committed by TensorFlower Gardener
parent f4bae1839c
commit 57e69437b4
9 changed files with 296 additions and 247 deletions
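For orientation, a minimal sketch of what the change means at a call site: the decorator keeps the same behavior but is now imported from Keras' layer_utils rather than from the tracking module. This is a hedged example, not code from the diff: it assumes a TensorFlow build that contains this commit, the MyLayerLike class is hypothetical, and stdlib inspect stands in for tf_inspect.

```python
import inspect

from tensorflow.python.keras.utils import layer_utils  # new home of the decorator


class MyLayerLike(object):  # hypothetical stand-in for a Keras Layer

  def call(self, inputs, training=None, **kwargs):
    return inputs

  @property
  @layer_utils.cached_per_instance  # previously: @tracking.cached_per_instance
  def _call_full_argspec(self):
    # Argspec inspection is expensive; compute once and cache per instance.
    return inspect.getfullargspec(self.call)


layer = MyLayerLike()
print(layer._call_full_argspec.args)   # ['self', 'inputs', 'training']
print(layer._call_full_argspec.varkw)  # 'kwargs'
```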

View File

@@ -2926,14 +2926,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._call_accepts_kwargs)
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_full_argspec(self):
# Argspec inspection is expensive and the call spec is used often, so it
# makes sense to cache the result.
return tf_inspect.getfullargspec(self.call)
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_fn_args(self):
all_args = self._call_full_argspec.args
# Scrub `self` that appears if a decorator was applied.
@@ -2942,7 +2942,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return all_args
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_fn_arg_defaults(self):
call_fn_args = self._call_fn_args
call_fn_defaults = self._call_full_argspec.defaults or []
@@ -2955,7 +2955,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return defaults
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_fn_arg_positions(self):
call_fn_arg_positions = dict()
for pos, arg in enumerate(self._call_fn_args):
@@ -2963,7 +2963,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return call_fn_arg_positions
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_accepts_kwargs(self):
return self._call_full_argspec.varkw is not None

View File

@@ -2342,14 +2342,14 @@ class Layer(base_layer.Layer):
self._call_accepts_kwargs)
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_full_argspec(self):
# Argspec inspection is expensive and the call spec is used often, so it
# makes sense to cache the result.
return tf_inspect.getfullargspec(self.call)
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_fn_args(self):
all_args = self._call_full_argspec.args
# Scrub `self` that appears if a decorator was applied.
@@ -2358,7 +2358,7 @@ class Layer(base_layer.Layer):
return all_args
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_fn_arg_positions(self):
call_fn_arg_positions = dict()
for pos, arg in enumerate(self._call_fn_args):
@@ -2366,12 +2366,12 @@ class Layer(base_layer.Layer):
return call_fn_arg_positions
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _call_accepts_kwargs(self):
return self._call_full_argspec.varkw is not None
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _should_compute_mask(self):
return ('mask' in self._call_fn_args or
getattr(self, 'compute_mask', None) is not None)

View File

@@ -49,6 +49,7 @@ py_library(
"//tensorflow/python/keras:backend_config",
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras/engine:base_layer_utils",
"//tensorflow/python/keras/utils:layer_utils",
"//tensorflow/python/keras/utils:tf_utils",
],
)

View File

@@ -39,6 +39,7 @@ from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -48,7 +49,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@@ -1207,12 +1207,12 @@ class OptimizerV2(trackable.Trackable):
return x.value()
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _dense_apply_args(self):
return tf_inspect.getfullargspec(self._resource_apply_dense).args
@property
@tracking.cached_per_instance
@layer_utils.cached_per_instance
def _sparse_apply_args(self):
return tf_inspect.getfullargspec(self._resource_apply_sparse).args
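The optimizer caches its `_resource_apply_*` argspecs for the same reason the Layer caches its call spec: full argspec inspection is slow relative to a cached property lookup. Below is a rough, self-contained illustration of that trade-off; it uses stdlib inspect as a stand-in for tf_inspect, re-creates the decorator with the same WeakKeyDictionary scheme added later in this diff, and FakeOptimizer is a hypothetical stand-in for OptimizerV2. Timings vary by machine.

```python
import functools
import inspect
import timeit
import weakref


def cached_per_instance(f):
  # Same scheme as the decorator this commit adds to keras/utils/layer_utils.py.
  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


class FakeOptimizer(object):  # hypothetical stand-in for OptimizerV2

  def _resource_apply_dense(self, grad, handle, apply_state=None):
    pass

  @property
  @cached_per_instance
  def _dense_apply_args(self):
    return inspect.getfullargspec(self._resource_apply_dense).args


opt = FakeOptimizer()
uncached = timeit.timeit(
    lambda: inspect.getfullargspec(opt._resource_apply_dense).args, number=10000)
cached = timeit.timeit(lambda: opt._dense_apply_args, number=10000)
print(uncached, cached)  # the cached property is typically far cheaper per call
```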

View File

@@ -301,6 +301,19 @@ tf_py_test(
],
)
tf_py_test(
name = "layer_utils_test",
size = "small",
srcs = ["layer_utils_test.py"],
python_version = "PY3",
deps = [
":layer_utils",
"//tensorflow/python:client_testlib",
"//tensorflow/python/training/tracking",
"//third_party/py/numpy",
],
)
tf_py_test(
name = "np_utils_test",
size = "small",

View File

@@ -19,6 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import weakref
import numpy as np
import six
@@ -404,3 +407,99 @@ def is_builtin_layer(layer):
# of the base layer class.
return (layer._keras_api_names != ('keras.layers.Layer',) and
layer._keras_api_names_v1 != ('keras.layers.Layer',))
def cached_per_instance(f):
"""Lightweight decorator for caching lazily constructed properties.
When to use:
This decorator provides simple caching with minimal overhead. It is designed
for properties which are expensive to compute and static over the life of a
class instance, and provides no mechanism for cache invalidation. Thus it is
best suited for lazily exposing derived properties of other static data.
For classes with custom getattr / setattr behavior (such as trackable
objects), storing cache results as object attributes is not performant.
Instead, a specialized cache can significantly reduce property lookup
overhead. (While still allowing the decorated property to be lazily computed.)
Consider the following class:
```
class MyClass(object):
def __setattr__(self, key, value):
# Some expensive class specific code
# ...
# ...
super(MyClass, self).__setattr__(key, value)
@property
def thing(self):
# `thing` is expensive to compute (and may not even be requested), so we
# want to lazily compute it and then cache it.
output = getattr(self, '_thing', None)
if output is None:
self._thing = output = compute_thing(self)
return output
```
It's also worth noting that ANY overriding of __setattr__, even something as
simple as:
```
def __setattr__(self, key, value):
super(MyClass, self).__setattr__(key, value)
```
Slows down attribute assignment by nearly 10x.
By contrast, replacing the definition of `thing` with the following sidesteps
the expensive __setattr__ altogether:
'''
@property
@layer_utils.cached_per_instance
def thing(self):
# `thing` is expensive to compute (and may not even be requested), so we
# want to lazily compute it and then cache it.
return compute_thing(self)
'''
Performance:
The overhead for this decorator is ~0.4 us / call. A much lower overhead
implementation (~0.085 us / call) can be achieved by using a custom dict type:
```
def dict_based_cache(f):
class Cache(dict):
__slots__ = ()
def __missing__(self, key):
self[key] = output = f(key)
return output
return property(Cache().__getitem__)
```
However, that implementation holds class instances as keys, and as a result
blocks garbage collection. (Modifying it to use weakrefs as keys raises the
lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary implementation
below turns out to be more prudent.
Args:
f: The function to cache.
Returns:
f decorated with simple caching behavior.
"""
cache = weakref.WeakKeyDictionary()
@functools.wraps(f)
def wrapped(item):
output = cache.get(item)
if output is None:
cache[item] = output = f(item)
return output
wrapped.cache = cache
return wrapped
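A minimal usage sketch of the decorator added above, illustrating the docstring's reason for preferring a WeakKeyDictionary over a plain dict: cached entries do not keep their instances alive. It assumes a TensorFlow build that includes this change; the Expensive class is hypothetical.

```python
import gc

from tensorflow.python.keras.utils import layer_utils


class Expensive(object):  # hypothetical example class, not part of the commit

  @property
  @layer_utils.cached_per_instance
  def thing(self):
    return sum(range(10 ** 6))  # stand-in for an expensive computation


obj = Expensive()
_ = obj.thing                        # first access computes and caches the value
cache = Expensive.thing.fget.cache   # the WeakKeyDictionary created by the decorator
print(len(cache))                    # 1: one entry, keyed weakly by `obj`

del obj
gc.collect()
print(len(cache))                    # 0: dropping the instance also drops its entry
```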

View File

@@ -0,0 +1,170 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layer_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import multiprocessing.dummy
import pickle
import time
import timeit
import numpy as np
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import tracking
_PICKLEABLE_CALL_COUNT = collections.Counter()
class MyPickleableObject(tracking.AutoTrackable):
"""Needed for InterfaceTests.test_property_cache_serialization.
This class must be at the top level. This is a constraint of pickle,
unrelated to `cached_per_instance`.
"""
@property
@layer_utils.cached_per_instance
def my_id(self):
_PICKLEABLE_CALL_COUNT[self] += 1
return id(self)
class LayerUtilsTest(test.TestCase):
def test_property_cache(self):
test_counter = collections.Counter()
class MyObject(tracking.AutoTrackable):
def __init__(self):
super(MyObject, self).__init__()
self._frozen = True
def __setattr__(self, key, value):
"""Enforce that cache does not set attribute on MyObject."""
if getattr(self, "_frozen", False):
raise ValueError("Cannot mutate when frozen.")
return super(MyObject, self).__setattr__(key, value)
@property
@layer_utils.cached_per_instance
def test_property(self):
test_counter[id(self)] += 1
return id(self)
first_object = MyObject()
second_object = MyObject()
# Make sure the objects return the correct values
self.assertEqual(first_object.test_property, id(first_object))
self.assertEqual(second_object.test_property, id(second_object))
# Make sure the cache does not share across objects
self.assertNotEqual(first_object.test_property, second_object.test_property)
# Check again (Now the values should be cached.)
self.assertEqual(first_object.test_property, id(first_object))
self.assertEqual(second_object.test_property, id(second_object))
# Count the function calls to make sure the cache is actually being used.
self.assertAllEqual(tuple(test_counter.values()), (1, 1))
def test_property_cache_threaded(self):
call_count = collections.Counter()
class MyObject(tracking.AutoTrackable):
@property
@layer_utils.cached_per_instance
def test_property(self):
# Random sleeps to ensure that the execution thread changes
# mid-computation.
call_count["test_property"] += 1
time.sleep(np.random.random() + 1.)
# Use a RandomState which is seeded off the instance's id (the mod is
# because numpy limits the range of seeds) to ensure that an instance
# returns the same value in different threads, but different instances
# return different values.
return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
def get_test_property(self, _):
"""Function provided to .map for threading test."""
return self.test_property
# Test that multiple threads return the same value. This requires that
# the underlying function is repeatable, as cached_property makes no attempt
# to prioritize the first call.
test_obj = MyObject()
with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
# Intentionally make a large pool (even when there are only a small number
# of cpus) to ensure that the runtime switches threads.
results = pool.map(test_obj.get_test_property, range(64))
self.assertEqual(len(set(results)), 1)
# Make sure we actually are testing threaded behavior.
self.assertGreater(call_count["test_property"], 1)
# Make sure new threads still cache hit.
with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
start_time = timeit.default_timer() # Don't time pool instantiation.
results = pool.map(test_obj.get_test_property, range(4))
total_time = timeit.default_timer() - start_time
# Note(taylorrobie): The reason that it is safe to time a unit test is that
# a cache hit will be << 1 second, and a cache miss is
# guaranteed to be >= 1 second. Empirically confirmed by
# 100,000 runs with no flakes.
self.assertLess(total_time, 0.95)
def test_property_cache_serialization(self):
# Reset call count. .keys() must be wrapped in a list, because otherwise we
# would mutate the iterator while iterating.
for k in list(_PICKLEABLE_CALL_COUNT.keys()):
_PICKLEABLE_CALL_COUNT.pop(k)
first_instance = MyPickleableObject()
self.assertEqual(id(first_instance), first_instance.my_id)
# Test that we can pickle and un-pickle
second_instance = pickle.loads(pickle.dumps(first_instance))
self.assertEqual(id(second_instance), second_instance.my_id)
self.assertNotEqual(first_instance.my_id, second_instance.my_id)
# Make sure de-serialized object uses the cache.
self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
# Make sure the decorator cache is not being serialized with the object.
expected_size = len(pickle.dumps(second_instance))
for _ in range(5):
# Add some more entries to the cache.
_ = MyPickleableObject().my_id
self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
size_check_instance = MyPickleableObject()
_ = size_check_instance.my_id
self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
if __name__ == "__main__":
test.main()
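The serialization test above depends on the cache living on the decorated getter rather than on the instance, so pickling an object never drags other instances' cached values along. A short sketch of that point follows; it assumes a TensorFlow build containing this commit, and the Tracked class is hypothetical.

```python
import pickle

from tensorflow.python.keras.utils import layer_utils


class Tracked(object):  # hypothetical example class, not part of the commit

  @property
  @layer_utils.cached_per_instance
  def fingerprint(self):
    return id(self)


a = Tracked()
_ = a.fingerprint
# The cache hangs off the wrapped getter function, keyed weakly by instance.
print(type(Tracked.fingerprint.fget.cache).__name__)  # WeakKeyDictionary

before = len(pickle.dumps(a))
others = [Tracked() for _ in range(5)]
for other in others:
  _ = other.fingerprint              # add more cache entries for live instances
print(len(pickle.dumps(a)) == before)  # True: the cache is never serialized
```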

View File

@@ -18,8 +18,6 @@ from __future__ import division
from __future__ import print_function
import copy
import functools
import weakref
from absl import logging
@@ -357,100 +355,5 @@ class Asset(base.Trackable):
return self._path
def cached_per_instance(f):
"""Lightweight decorator for caching lazily constructed properties.
When to use:
This decorator provides simple caching with minimal overhead. It is designed
for properties which are expensive to compute and static over the life of a
class instance, and provides no mechanism for cache invalidation. Thus it is
best suited for lazily exposing derived properties of other static data.
For classes with custom getattr / setattr behavior (such as trackable
objects), storing cache results as object attributes is not performant.
Instead, a specialized cache can significantly reduce property lookup
overhead. (While still allowing the decorated property to be lazily computed.)
Consider the following class:
```
class MyClass(object):
def __setattr__(self, key, value):
# Some expensive class specific code
# ...
# ...
super(MyClass, self).__setattr__(key, value)
@property
def thing(self):
# `thing` is expensive to compute (and may not even be requested), so we
# want to lazily compute it and then cache it.
output = getattr(self, '_thing', None)
if output is None:
self._thing = output = compute_thing(self)
return output
```
It's also worth noting that ANY overriding of __setattr__, even something as
simple as:
```
def __setattr__(self, key, value):
super(MyClass, self).__setattr__(key, value)
```
Slows down attribute assignment by nearly 10x.
By contrast, replacing the definition of `thing` with the following sidesteps
the expensive __setattr__ altogether:
'''
@property
@tracking.cached_per_instance
def thing(self):
# `thing` is expensive to compute (and may not even be requested), so we
# want to lazily compute it and then cache it.
return compute_thing(self)
'''
Performance:
The overhead for this decorator is ~0.4 us / call. A much lower overhead
implementation (~0.085 us / call) can be achieved by using a custom dict type:
```
def dict_based_cache(f):
class Cache(dict):
__slots__ = ()
def __missing__(self, key):
self[key] = output = f(key)
return output
return property(Cache().__getitem__)
```
However, that implementation holds class instances as keys, and as a result
blocks garbage collection. (And modifying it to use weakref's as keys raises
the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
implementation below turns out to be more prudent.
Args:
f: The function to cache.
Returns:
f decorated with simple caching behavior.
"""
cache = weakref.WeakKeyDictionary()
@functools.wraps(f)
def wrapped(item):
output = cache.get(item)
if output is None:
cache[item] = output = f(item)
return output
wrapped.cache = cache
return wrapped
ops.register_tensor_conversion_function(
Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw))

View File

@@ -16,13 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import multiprocessing.dummy
import os
import pickle
import time
import timeit
import numpy as np
@@ -35,23 +29,6 @@ from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
_PICKLEABLE_CALL_COUNT = collections.Counter()
class MyPickleableObject(tracking.AutoTrackable):
"""Needed for InterfaceTests.test_property_cache_serialization.
This class must be at the top level. This is a constraint of pickle,
unrelated to `cached_per_instance`.
"""
@property
@tracking.cached_per_instance
def my_id(self):
_PICKLEABLE_CALL_COUNT[self] += 1
return id(self)
class InterfaceTests(test.TestCase):
def testMultipleAssignment(self):
@@ -169,120 +146,6 @@ class InterfaceTests(test.TestCase):
self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
self.evaluate(a.tensors))
def test_property_cache(self):
test_counter = collections.Counter()
class MyObject(tracking.AutoTrackable):
def __init__(self):
super(MyObject, self).__init__()
self._frozen = True
def __setattr__(self, key, value):
"""Enforce that cache does not set attribute on MyObject."""
if getattr(self, "_frozen", False):
raise ValueError("Cannot mutate when frozen.")
return super(MyObject, self).__setattr__(key, value)
@property
@tracking.cached_per_instance
def test_property(self):
test_counter[id(self)] += 1
return id(self)
first_object = MyObject()
second_object = MyObject()
# Make sure the objects return the correct values
self.assertEqual(first_object.test_property, id(first_object))
self.assertEqual(second_object.test_property, id(second_object))
# Make sure the cache does not share across objects
self.assertNotEqual(first_object.test_property, second_object.test_property)
# Check again (Now the values should be cached.)
self.assertEqual(first_object.test_property, id(first_object))
self.assertEqual(second_object.test_property, id(second_object))
# Count the function calls to make sure the cache is actually being used.
self.assertAllEqual(tuple(test_counter.values()), (1, 1))
def test_property_cache_threaded(self):
call_count = collections.Counter()
class MyObject(tracking.AutoTrackable):
@property
@tracking.cached_per_instance
def test_property(self):
# Random sleeps to ensure that the execution thread changes
# mid-computation.
call_count["test_property"] += 1
time.sleep(np.random.random() + 1.)
# Use a RandomState which is seeded off the instance's id (the mod is
# because numpy limits the range of seeds) to ensure that an instance
# returns the same value in different threads, but different instances
# return different values.
return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
def get_test_property(self, _):
"""Function provided to .map for threading test."""
return self.test_property
# Test that multiple threads return the same value. This requires that
# the underlying function is repeatable, as cached_property makes no attempt
# to prioritize the first call.
test_obj = MyObject()
with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
# Intentionally make a large pool (even when there are only a small number
# of cpus) to ensure that the runtime switches threads.
results = pool.map(test_obj.get_test_property, range(64))
self.assertEqual(len(set(results)), 1)
# Make sure we actually are testing threaded behavior.
self.assertGreater(call_count["test_property"], 1)
# Make sure new threads still cache hit.
with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
start_time = timeit.default_timer() # Don't time pool instantiation.
results = pool.map(test_obj.get_test_property, range(4))
total_time = timeit.default_timer() - start_time
# Note(taylorrobie): The reason that it is safe to time a unit test is that
# a cache hit will be << 1 second, and a cache miss is
# guaranteed to be >= 1 second. Empirically confirmed by
# 100,000 runs with no flakes.
self.assertLess(total_time, 0.95)
def test_property_cache_serialization(self):
# Reset call count. .keys() must be wrapped in a list, because otherwise we
# would mutate the iterator while iterating.
for k in list(_PICKLEABLE_CALL_COUNT.keys()):
_PICKLEABLE_CALL_COUNT.pop(k)
first_instance = MyPickleableObject()
self.assertEqual(id(first_instance), first_instance.my_id)
# Test that we can pickle and un-pickle
second_instance = pickle.loads(pickle.dumps(first_instance))
self.assertEqual(id(second_instance), second_instance.my_id)
self.assertNotEqual(first_instance.my_id, second_instance.my_id)
# Make sure de-serialized object uses the cache.
self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
# Make sure the decorator cache is not being serialized with the object.
expected_size = len(pickle.dumps(second_instance))
for _ in range(5):
# Add some more entries to the cache.
_ = MyPickleableObject().my_id
self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
size_check_instance = MyPickleableObject()
_ = size_check_instance.my_id
self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
class _DummyResource(tracking.TrackableResource):