diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index c01c3d96aec..5ac0a6dd997 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -2926,14 +2926,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
                             self._call_accepts_kwargs)
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_full_argspec(self):
     # Argspec inspection is expensive and the call spec is used often, so it
     # makes sense to cache the result.
     return tf_inspect.getfullargspec(self.call)
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_fn_args(self):
     all_args = self._call_full_argspec.args
     # Scrub `self` that appears if a decorator was applied.
@@ -2942,7 +2942,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     return all_args
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_fn_arg_defaults(self):
     call_fn_args = self._call_fn_args
     call_fn_defaults = self._call_full_argspec.defaults or []
@@ -2955,7 +2955,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     return defaults
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_fn_arg_positions(self):
     call_fn_arg_positions = dict()
     for pos, arg in enumerate(self._call_fn_args):
@@ -2963,7 +2963,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     return call_fn_arg_positions
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_accepts_kwargs(self):
     return self._call_full_argspec.varkw is not None
 
diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py
index f047d84d16a..536efb52ad1 100644
--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py
@@ -2342,14 +2342,14 @@ class Layer(base_layer.Layer):
                             self._call_accepts_kwargs)
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_full_argspec(self):
     # Argspec inspection is expensive and the call spec is used often, so it
     # makes sense to cache the result.
     return tf_inspect.getfullargspec(self.call)
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_fn_args(self):
     all_args = self._call_full_argspec.args
     # Scrub `self` that appears if a decorator was applied.
@@ -2358,7 +2358,7 @@ class Layer(base_layer.Layer):
     return all_args
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_fn_arg_positions(self):
     call_fn_arg_positions = dict()
     for pos, arg in enumerate(self._call_fn_args):
@@ -2366,12 +2366,12 @@ class Layer(base_layer.Layer):
     return call_fn_arg_positions
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _call_accepts_kwargs(self):
     return self._call_full_argspec.varkw is not None
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _should_compute_mask(self):
     return ('mask' in self._call_fn_args or
             getattr(self, 'compute_mask', None) is not None)
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index 9a317e5d114..d5341006e46 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -49,6 +49,7 @@ py_library(
         "//tensorflow/python/keras:backend_config",
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras/engine:base_layer_utils",
+        "//tensorflow/python/keras/utils:layer_utils",
         "//tensorflow/python/keras/utils:tf_utils",
     ],
 )
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index c533b2c40c1..e6b4458ca8d 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -39,6 +39,7 @@ from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
 from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -48,7 +49,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.saved_model import revived_types
 from tensorflow.python.training.tracking import base as trackable
-from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -1207,12 +1207,12 @@ class OptimizerV2(trackable.Trackable):
     return x.value()
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _dense_apply_args(self):
     return tf_inspect.getfullargspec(self._resource_apply_dense).args
 
   @property
-  @tracking.cached_per_instance
+  @layer_utils.cached_per_instance
   def _sparse_apply_args(self):
     return tf_inspect.getfullargspec(self._resource_apply_sparse).args
 
diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD
index 899701d624c..38e3c8e66af 100644
--- a/tensorflow/python/keras/utils/BUILD
+++ b/tensorflow/python/keras/utils/BUILD
@@ -301,6 +301,19 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "layer_utils_test",
+    size = "small",
+    srcs = ["layer_utils_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":layer_utils",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/training/tracking",
+        "//third_party/py/numpy",
+    ],
+)
+
 tf_py_test(
     name = "np_utils_test",
     size = "small",
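Context for why the properties touched above are cached at all: `tf_inspect.getfullargspec` re-derives the full argspec (unwrapping decorators along the way) on every call, which is far slower than reading back a cached result. A rough standalone sketch of the gap, using the standard library's `inspect` as a stand-in for `tf_inspect` (the function and numbers here are illustrative, not from this patch):

```
# Illustrative micro-benchmark: inspecting on every access vs. inspecting once.
import inspect
import timeit


def call(self, inputs, training=None, mask=None):  # stand-in for Layer.call
  return inputs

# Pays the full getfullargspec cost on every lookup.
uncached = timeit.timeit(lambda: inspect.getfullargspec(call).args, number=10000)

# Inspects once and reuses the result -- what cached_per_instance buys.
spec = inspect.getfullargspec(call)
cached = timeit.timeit(lambda: spec.args, number=10000)

print("uncached: %.4f s, cached: %.4f s" % (uncached, cached))
```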
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index d2d3d919fff..3195bb0eb13 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -19,6 +19,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import weakref
+
 import numpy as np
 import six
 
@@ -404,3 +407,99 @@ def is_builtin_layer(layer):
   # of the base layer class.
   return (layer._keras_api_names != ('keras.layers.Layer',) and
           layer._keras_api_names_v1 != ('keras.layers.Layer',))
+
+
+def cached_per_instance(f):
+  """Lightweight decorator for caching lazily constructed properties.
+
+  When to use:
+  This decorator provides simple caching with minimal overhead. It is designed
+  for properties which are expensive to compute and static over the life of a
+  class instance, and provides no mechanism for cache invalidation. Thus it is
+  best suited for lazily exposing derived properties of other static data.
+
+  For classes with custom getattr / setattr behavior (such as trackable
+  objects), storing cache results as object attributes is not performant.
+  Instead, a specialized cache can significantly reduce property lookup
+  overhead. (While still allowing the decorated property to be lazily
+  computed.) Consider the following class:
+
+  ```
+  class MyClass(object):
+    def __setattr__(self, key, value):
+      # Some expensive class specific code
+      # ...
+      # ...
+
+      super(MyClass, self).__setattr__(key, value)
+
+    @property
+    def thing(self):
+      # `thing` is expensive to compute (and may not even be requested), so we
+      # want to lazily compute it and then cache it.
+      output = getattr(self, '_thing', None)
+      if output is None:
+        self._thing = output = compute_thing(self)
+      return output
+  ```
+
+  It's also worth noting that ANY overriding of __setattr__, even something as
+  simple as:
+  ```
+  def __setattr__(self, key, value):
+    super(MyClass, self).__setattr__(key, value)
+  ```
+  slows down attribute assignment by nearly 10x.
+
+  By contrast, replacing the definition of `thing` with the following
+  sidesteps the expensive __setattr__ altogether:
+
+  '''
+  @property
+  @cached_per_instance
+  def thing(self):
+    # `thing` is expensive to compute (and may not even be requested), so we
+    # want to lazily compute it and then cache it.
+    return compute_thing(self)
+  '''
+
+  Performance:
+  The overhead for this decorator is ~0.4 us / call. A much lower overhead
+  implementation (~0.085 us / call) can be achieved by using a custom dict
+  type:
+
+  ```
+  def dict_based_cache(f):
+    class Cache(dict):
+      __slots__ = ()
+      def __missing__(self, key):
+        self[key] = output = f(key)
+        return output
+
+    return property(Cache().__getitem__)
+  ```
+
+  However, that implementation holds class instances as keys, and as a result
+  blocks garbage collection. (And modifying it to use weakrefs as keys raises
+  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
+  implementation below turns out to be more prudent.
+
+  Args:
+    f: The function to cache.
+
+  Returns:
+    f decorated with simple caching behavior.
+  """
+
+  cache = weakref.WeakKeyDictionary()
+
+  @functools.wraps(f)
+  def wrapped(item):
+    output = cache.get(item)
+    if output is None:
+      cache[item] = output = f(item)
+    return output
+
+  wrapped.cache = cache
+  return wrapped
+
diff --git a/tensorflow/python/keras/utils/layer_utils_test.py b/tensorflow/python/keras/utils/layer_utils_test.py
new file mode 100644
index 00000000000..a4e53a21aba
--- /dev/null
+++ b/tensorflow/python/keras/utils/layer_utils_test.py
@@ -0,0 +1,170 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for layer_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import multiprocessing.dummy
+import pickle
+import time
+import timeit
+
+import numpy as np
+
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.platform import test
+from tensorflow.python.training.tracking import tracking
+
+
+_PICKLEABLE_CALL_COUNT = collections.Counter()
+
+
+class MyPickleableObject(tracking.AutoTrackable):
+  """Needed for InterfaceTests.test_property_cache_serialization.
+
+  This class must be at the top level. This is a constraint of pickle,
+  unrelated to `cached_per_instance`.
+  """
+
+  @property
+  @layer_utils.cached_per_instance
+  def my_id(self):
+    _PICKLEABLE_CALL_COUNT[self] += 1
+    return id(self)
+
+
+class LayerUtilsTest(test.TestCase):
+
+  def test_property_cache(self):
+    test_counter = collections.Counter()
+
+    class MyObject(tracking.AutoTrackable):
+
+      def __init__(self):
+        super(MyObject, self).__init__()
+        self._frozen = True
+
+      def __setattr__(self, key, value):
+        """Enforce that cache does not set attribute on MyObject."""
+        if getattr(self, "_frozen", False):
+          raise ValueError("Cannot mutate when frozen.")
+        return super(MyObject, self).__setattr__(key, value)
+
+      @property
+      @layer_utils.cached_per_instance
+      def test_property(self):
+        test_counter[id(self)] += 1
+        return id(self)
+
+    first_object = MyObject()
+    second_object = MyObject()
+
+    # Make sure the objects return the correct values.
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Make sure the cache does not share across objects.
+    self.assertNotEqual(first_object.test_property,
+                        second_object.test_property)
+
+    # Check again. (Now the values should be cached.)
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Count the function calls to make sure the cache is actually being used.
+    self.assertAllEqual(tuple(test_counter.values()), (1, 1))
+
+  def test_property_cache_threaded(self):
+    call_count = collections.Counter()
+
+    class MyObject(tracking.AutoTrackable):
+
+      @property
+      @layer_utils.cached_per_instance
+      def test_property(self):
+        # Random sleeps to ensure that the execution thread changes
+        # mid-computation.
+        call_count["test_property"] += 1
+        time.sleep(np.random.random() + 1.)
+
+        # Use a RandomState which is seeded off the instance's id (the mod is
+        # because numpy limits the range of seeds) to ensure that an instance
+        # returns the same value in different threads, but different instances
+        # return different values.
+        return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
+
+      def get_test_property(self, _):
+        """Function provided to .map for threading test."""
+        return self.test_property
+
+    # Test that multiple threads return the same value. This requires that
+    # the underlying function is repeatable, as cached_per_instance makes no
+    # attempt to prioritize the first call.
+    test_obj = MyObject()
+    with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
+      # Intentionally make a large pool (even when there are only a small
+      # number of cpus) to ensure that the runtime switches threads.
+      results = pool.map(test_obj.get_test_property, range(64))
+    self.assertEqual(len(set(results)), 1)
+
+    # Make sure we actually are testing threaded behavior.
+    self.assertGreater(call_count["test_property"], 1)
+
+    # Make sure new threads still cache hit.
+    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
+      start_time = timeit.default_timer()  # Don't time pool instantiation.
+      results = pool.map(test_obj.get_test_property, range(4))
+    total_time = timeit.default_timer() - start_time
+
+    # Note(taylorrobie): The reason that it is safe to time a unit test is
+    #                    that a cache hit will be << 1 second, and a cache
+    #                    miss is guaranteed to be >= 1 second. Empirically
+    #                    confirmed by 100,000 runs with no flakes.
+    self.assertLess(total_time, 0.95)
+
+  def test_property_cache_serialization(self):
+    # Reset call count. .keys() must be wrapped in a list, because otherwise
+    # we would mutate the iterator while iterating.
+    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
+      _PICKLEABLE_CALL_COUNT.pop(k)
+
+    first_instance = MyPickleableObject()
+    self.assertEqual(id(first_instance), first_instance.my_id)
+
+    # Test that we can pickle and un-pickle.
+    second_instance = pickle.loads(pickle.dumps(first_instance))
+
+    self.assertEqual(id(second_instance), second_instance.my_id)
+    self.assertNotEqual(first_instance.my_id, second_instance.my_id)
+
+    # Make sure the de-serialized object uses the cache.
+    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
+
+    # Make sure the decorator cache is not being serialized with the object.
+    expected_size = len(pickle.dumps(second_instance))
+    for _ in range(5):
+      # Add some more entries to the cache.
+      _ = MyPickleableObject().my_id
+    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
+    size_check_instance = MyPickleableObject()
+    _ = size_check_instance.my_id
+    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
+
+
+if __name__ == "__main__":
+  test.main()
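One behavior the threaded test above leans on is worth spelling out: the decorator takes no lock, so concurrent first accesses can each execute the decorated function, with the last writer winning. The decorated function must therefore be repeatable. A small sketch (class and names hypothetical) of that race:

```
import threading

from tensorflow.python.keras.utils import layer_utils


class Racy(object):  # hypothetical class for illustration
  calls = 0

  @property
  @layer_utils.cached_per_instance
  def value(self):
    # With no lock in the cache, several threads may all miss and run this
    # body; each run must return the same value for the cache to be coherent.
    type(self).calls += 1
    return 42


obj = Racy()
threads = [threading.Thread(target=lambda: obj.value) for _ in range(8)]
for t in threads:
  t.start()
for t in threads:
  t.join()
assert obj.value == 42  # always the same value
assert Racy.calls >= 1  # but the body may have run more than once
```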
- """ - - cache = weakref.WeakKeyDictionary() - - @functools.wraps(f) - def wrapped(item): - output = cache.get(item) - if output is None: - cache[item] = output = f(item) - return output - - wrapped.cache = cache - return wrapped - - ops.register_tensor_conversion_function( Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw)) diff --git a/tensorflow/python/training/tracking/tracking_test.py b/tensorflow/python/training/tracking/tracking_test.py index e2b01964bb3..3d6be8c0f4b 100644 --- a/tensorflow/python/training/tracking/tracking_test.py +++ b/tensorflow/python/training/tracking/tracking_test.py @@ -16,13 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import contextlib -import multiprocessing.dummy import os -import pickle -import time -import timeit import numpy as np @@ -35,23 +29,6 @@ from tensorflow.python.training.tracking import util from tensorflow.python.util import nest -_PICKLEABLE_CALL_COUNT = collections.Counter() - - -class MyPickleableObject(tracking.AutoTrackable): - """Needed for InterfaceTests.test_property_cache_serialization. - - This class must be at the top level. This is a constraint of pickle, - unrelated to `cached_per_instance`. - """ - - @property - @tracking.cached_per_instance - def my_id(self): - _PICKLEABLE_CALL_COUNT[self] += 1 - return id(self) - - class InterfaceTests(test.TestCase): def testMultipleAssignment(self): @@ -169,120 +146,6 @@ class InterfaceTests(test.TestCase): self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]}, self.evaluate(a.tensors)) - def test_property_cache(self): - test_counter = collections.Counter() - - class MyObject(tracking.AutoTrackable): - - def __init__(self): - super(MyObject, self).__init__() - self._frozen = True - - def __setattr__(self, key, value): - """Enforce that cache does not set attribute on MyObject.""" - if getattr(self, "_frozen", False): - raise ValueError("Cannot mutate when frozen.") - return super(MyObject, self).__setattr__(key, value) - - @property - @tracking.cached_per_instance - def test_property(self): - test_counter[id(self)] += 1 - return id(self) - - first_object = MyObject() - second_object = MyObject() - - # Make sure the objects return the correct values - self.assertEqual(first_object.test_property, id(first_object)) - self.assertEqual(second_object.test_property, id(second_object)) - - # Make sure the cache does not share across objects - self.assertNotEqual(first_object.test_property, second_object.test_property) - - # Check again (Now the values should be cached.) - self.assertEqual(first_object.test_property, id(first_object)) - self.assertEqual(second_object.test_property, id(second_object)) - - # Count the function calls to make sure the cache is actually being used. - self.assertAllEqual(tuple(test_counter.values()), (1, 1)) - - def test_property_cache_threaded(self): - call_count = collections.Counter() - - class MyObject(tracking.AutoTrackable): - - @property - @tracking.cached_per_instance - def test_property(self): - # Random sleeps to ensure that the execution thread changes - # mid-computation. - call_count["test_property"] += 1 - time.sleep(np.random.random() + 1.) - - # Use a RandomState which is seeded off the instance's id (the mod is - # because numpy limits the range of seeds) to ensure that an instance - # returns the same value in different threads, but different instances - # return different values. 
- return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16)) - - def get_test_property(self, _): - """Function provided to .map for threading test.""" - return self.test_property - - # Test that multiple threads return the same value. This requires that - # the underlying function is repeatable, as cached_property makes no attempt - # to prioritize the first call. - test_obj = MyObject() - with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool: - # Intentionally make a large pool (even when there are only a small number - # of cpus) to ensure that the runtime switches threads. - results = pool.map(test_obj.get_test_property, range(64)) - self.assertEqual(len(set(results)), 1) - - # Make sure we actually are testing threaded behavior. - self.assertGreater(call_count["test_property"], 1) - - # Make sure new threads still cache hit. - with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool: - start_time = timeit.default_timer() # Don't time pool instantiation. - results = pool.map(test_obj.get_test_property, range(4)) - total_time = timeit.default_timer() - start_time - - # Note(taylorrobie): The reason that it is safe to time a unit test is that - # a cache hit will be << 1 second, and a cache miss is - # guaranteed to be >= 1 second. Empirically confirmed by - # 100,000 runs with no flakes. - self.assertLess(total_time, 0.95) - - def test_property_cache_serialization(self): - # Reset call count. .keys() must be wrapped in a list, because otherwise we - # would mutate the iterator while iterating. - for k in list(_PICKLEABLE_CALL_COUNT.keys()): - _PICKLEABLE_CALL_COUNT.pop(k) - - first_instance = MyPickleableObject() - self.assertEqual(id(first_instance), first_instance.my_id) - - # Test that we can pickle and un-pickle - second_instance = pickle.loads(pickle.dumps(first_instance)) - - self.assertEqual(id(second_instance), second_instance.my_id) - self.assertNotEqual(first_instance.my_id, second_instance.my_id) - - # Make sure de-serialized object uses the cache. - self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1) - - # Make sure the decorator cache is not being serialized with the object. - expected_size = len(pickle.dumps(second_instance)) - for _ in range(5): - # Add some more entries to the cache. - _ = MyPickleableObject().my_id - self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7) - size_check_instance = MyPickleableObject() - _ = size_check_instance.my_id - self.assertEqual(expected_size, len(pickle.dumps(size_check_instance))) - class _DummyResource(tracking.TrackableResource):