diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index cefa4013f1c..86c0e00f8b9 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -2046,12 +2046,12 @@ class Layer(module.Module):
     return True
 
   @property
+  @tracking.cached_per_instance
   def _call_fn_args(self):
-    if getattr(self, '__call_fn_args', None) is None:
-      self.__call_fn_args = function_utils.fn_args(self.call)
-    return self.__call_fn_args
+    return function_utils.fn_args(self.call)
 
   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return ('mask' in self._call_fn_args or
             getattr(self, 'compute_mask', None) is not None)
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 0178867c946..e6e6f3448f4 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -50,6 +50,7 @@ from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
+from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization
@@ -513,6 +514,7 @@ class Network(base_layer.Layer):
     return weights
 
   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return self._is_graph_network and super(Network, self)._should_compute_mask
 
diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py
index c61b697de1b..26392305932 100644
--- a/tensorflow/python/training/tracking/tracking.py
+++ b/tensorflow/python/training/tracking/tracking.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import weakref
+
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import dtypes
@@ -241,6 +244,100 @@ class TrackableAsset(base.Trackable):
     """Fetch the current asset path."""
     return self._path
 
+
+def cached_per_instance(f):
+  """Lightweight decorator for caching lazily constructed properties.
+
+  When to use:
+  This decorator provides simple caching with minimal overhead. It is designed
+  for properties which are expensive to compute and static over the life of a
+  class instance, and provides no mechanism for cache invalidation. Thus it is
+  best suited for lazily exposing derived properties of other static data.
+
+  For classes with custom getattr / setattr behavior (such as trackable
+  objects), storing cache results as object attributes is not performant.
+  Instead, a specialized cache can significantly reduce property lookup
+  overhead. (While still allowing the decorated property to be lazily computed.)
+  Consider the following class:
+
+  ```
+  class MyClass(object):
+    def __setattr__(self, key, value):
+      # Some expensive class specific code
+      # ...
+      # ...
+
+      super(MyClass, self).__setattr__(key, value)
+
+    @property
+    def thing(self):
+      # `thing` is expensive to compute (and may not even be requested), so we
+      # want to lazily compute it and then cache it.
+      output = getattr(self, '_thing', None)
+      if output is None:
+        self._thing = output = compute_thing(self)
+      return output
+  ```
+
+  It's also worth noting that ANY overriding of __setattr__, even something as
+  simple as:
+  ```
+  def __setattr__(self, key, value):
+    super(MyClass, self).__setattr__(key, value)
+  ```
+
+  slows down attribute assignment by nearly 10x.
+
+  By contrast, replacing the definition of `thing` with the following sidesteps
+  the expensive __setattr__ altogether:
+
+  ```
+  @property
+  @tracking.cached_per_instance
+  def thing(self):
+    # `thing` is expensive to compute (and may not even be requested), so we
+    # want to lazily compute it and then cache it.
+    return compute_thing(self)
+  ```
+
+  Performance:
+  The overhead for this decorator is ~0.4 us / call. A much lower overhead
+  implementation (~0.085 us / call) can be achieved by using a custom dict type:
+
+  ```
+  def dict_based_cache(f):
+    class Cache(dict):
+      __slots__ = ()
+      def __missing__(self, key):
+        self[key] = output = f(key)
+        return output
+
+    return property(Cache().__getitem__)
+  ```
+
+  However, that implementation holds class instances as keys, and as a result
+  blocks garbage collection. (And modifying it to use weakrefs as keys raises
+  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
+  implementation below turns out to be more prudent.
+
+  Args:
+    f: The function to cache.
+
+  Returns:
+    f decorated with simple caching behavior.
+  """
+
+  cache = weakref.WeakKeyDictionary()
+
+  @functools.wraps(f)
+  def wrapped(item):
+    output = cache.get(item)
+    if output is None:
+      cache[item] = output = f(item)
+    return output
+  return wrapped
+
+
 ops.register_tensor_conversion_function(
     TrackableAsset,
     lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
diff --git a/tensorflow/python/training/tracking/tracking_test.py b/tensorflow/python/training/tracking/tracking_test.py
index adef69f45bd..90e6c6cbd53 100644
--- a/tensorflow/python/training/tracking/tracking_test.py
+++ b/tensorflow/python/training/tracking/tracking_test.py
@@ -16,9 +16,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import contextlib
+import multiprocessing.dummy
 import os
+import pickle
+import time
+import timeit
 
-import numpy
+import numpy as np
 import six
 
 from tensorflow.python.framework import test_util
@@ -32,6 +38,23 @@ from tensorflow.python.training.tracking import util
 from tensorflow.python.util import nest
 
 
+_PICKLEABLE_CALL_COUNT = collections.Counter()
+
+
+class MyPickleableObject(tracking.AutoTrackable):
+  """Needed for InterfaceTests.test_property_cache_serialization.
+
+  This class must be at the top level. This is a constraint of pickle,
+  unrelated to `cached_per_instance`.
+ """ + + @property + @tracking.cached_per_instance + def my_id(self): + _PICKLEABLE_CALL_COUNT[self] += 1 + return id(self) + + class InterfaceTests(test.TestCase): def testMultipleAssignment(self): @@ -199,15 +222,129 @@ class InterfaceTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testAssertions(self): a = tracking.AutoTrackable() - a.l = {"k": [numpy.zeros([2, 2])]} - self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}), + a.l = {"k": [np.zeros([2, 2])]} + self.assertAllEqual(nest.flatten({"k": [np.zeros([2, 2])]}), nest.flatten(a.l)) - self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l) - nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]}) + self.assertAllClose({"k": [np.zeros([2, 2])]}, a.l) + nest.map_structure(self.assertAllClose, a.l, {"k": [np.zeros([2, 2])]}) a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]} - self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}, + self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]}, self.evaluate(a.tensors)) + def test_property_cache(self): + test_counter = collections.Counter() + + class MyObject(tracking.AutoTrackable): + + def __init__(self): + super(MyObject, self).__init__() + self._frozen = True + + def __setattr__(self, key, value): + """Enforce that cache does not set attribute on MyObject.""" + if getattr(self, "_frozen", False): + raise ValueError("Cannot mutate when frozen.") + return super(MyObject, self).__setattr__(key, value) + + @property + @tracking.cached_per_instance + def test_property(self): + test_counter[id(self)] += 1 + return id(self) + + first_object = MyObject() + second_object = MyObject() + + # Make sure the objects return the correct values + self.assertEqual(first_object.test_property, id(first_object)) + self.assertEqual(second_object.test_property, id(second_object)) + + # Make sure the cache does not share across objects + self.assertNotEqual(first_object.test_property, second_object.test_property) + + # Check again (Now the values should be cached.) + self.assertEqual(first_object.test_property, id(first_object)) + self.assertEqual(second_object.test_property, id(second_object)) + + # Count the function calls to make sure the cache is actually being used. + self.assertAllEqual(tuple(test_counter.values()), (1, 1)) + + def test_property_cache_threaded(self): + call_count = collections.Counter() + + class MyObject(tracking.AutoTrackable): + + @property + @tracking.cached_per_instance + def test_property(self): + # Random sleeps to ensure that the execution thread changes + # mid-computation. + call_count["test_property"] += 1 + time.sleep(np.random.random() + 1.) + + # Use a RandomState which is seeded off the instance's id (the mod is + # because numpy limits the range of seeds) to ensure that an instance + # returns the same value in different threads, but different instances + # return different values. + return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16)) + + def get_test_property(self, _): + """Function provided to .map for threading test.""" + return self.test_property + + # Test that multiple threads return the same value. This requires that + # the underlying function is repeatable, as cached_property makes no attempt + # to prioritize the first call. + test_obj = MyObject() + with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool: + # Intentionally make a large pool (even when there are only a small number + # of cpus) to ensure that the runtime switches threads. 
+      results = pool.map(test_obj.get_test_property, range(64))
+    self.assertEqual(len(set(results)), 1)
+
+    # Make sure we actually are testing threaded behavior.
+    self.assertGreater(call_count["test_property"], 1)
+
+    # Make sure new threads still cache hit.
+    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
+      start_time = timeit.default_timer()  # Don't time pool instantiation.
+      results = pool.map(test_obj.get_test_property, range(4))
+    total_time = timeit.default_timer() - start_time
+
+    # Note(taylorrobie): The reason that it is safe to time a unit test is that
+    #                    a cache hit will be << 1 second, and a cache miss is
+    #                    guaranteed to be >= 1 second. Empirically confirmed by
+    #                    100,000 runs with no flakes.
+    self.assertLess(total_time, 0.95)
+
+  def test_property_cache_serialization(self):
+    # Reset call count. .keys() must be wrapped in a list, because otherwise we
+    # would mutate the iterator while iterating.
+    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
+      _PICKLEABLE_CALL_COUNT.pop(k)
+
+    first_instance = MyPickleableObject()
+    self.assertEqual(id(first_instance), first_instance.my_id)
+
+    # Test that we can pickle and un-pickle.
+    second_instance = pickle.loads(pickle.dumps(first_instance))
+
+    self.assertEqual(id(second_instance), second_instance.my_id)
+    self.assertNotEqual(first_instance.my_id, second_instance.my_id)
+
+    # Make sure the de-serialized object uses the cache.
+    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
+
+    # Make sure the decorator cache is not being serialized with the object.
+    expected_size = len(pickle.dumps(second_instance))
+    for _ in range(5):
+      # Add some more entries to the cache.
+      _ = MyPickleableObject().my_id
+    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
+    size_check_instance = MyPickleableObject()
+    _ = size_check_instance.my_id
+    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
+
 
 class _DummyResource(tracking.TrackableResource):