Reduce Python overheads for lazily constructed Layer properties.
PiperOrigin-RevId: 248435402
parent 83668b0826
commit 42ac719705
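The change: Layer properties that were lazily computed and then stored back onto the instance (a pattern that is slow when `__setattr__` is overridden, as it is for trackable objects) are replaced by a `cached_per_instance` decorator that keeps results in a closure-held WeakKeyDictionary. The decorator's docstring below claims that any Python-level `__setattr__` override slows attribute assignment by nearly 10x; here is a minimal standalone benchmark sketch to reproduce that claim (class names are hypothetical and absolute numbers vary by machine):

```
import timeit


class Plain(object):
  pass


class Frozen(object):
  """The trivial __setattr__ override called out in the docstring below."""

  def __setattr__(self, key, value):
    super(Frozen, self).__setattr__(key, value)


plain, frozen = Plain(), Frozen()

# Any Python-level __setattr__ forces a function call on every assignment;
# that per-assignment cost is what the closure-held cache sidesteps.
for obj in (plain, frozen):
  t = timeit.timeit(lambda: setattr(obj, "x", 1), number=100000)
  print(type(obj).__name__, "%.3f us / assignment" % (t * 10))
```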
@@ -2046,12 +2046,12 @@ class Layer(module.Module):
     return True
 
   @property
+  @tracking.cached_per_instance
   def _call_fn_args(self):
-    if getattr(self, '__call_fn_args', None) is None:
-      self.__call_fn_args = function_utils.fn_args(self.call)
-    return self.__call_fn_args
+    return function_utils.fn_args(self.call)
 
   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return ('mask' in self._call_fn_args or
             getattr(self, 'compute_mask', None) is not None)
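Judging by the `class Layer(module.Module)` context, this hunk is Keras's base Layer. `function_utils.fn_args` is a TensorFlow-internal helper; the following is a rough standard-library approximation of what the now-cached `_call_fn_args` computes, for illustration only (the real helper handles more cases, such as partials):

```
import inspect


def fn_args(fn):
  # Rough stand-in for TF's internal function_utils.fn_args: the tuple of
  # argument names of fn.
  return tuple(inspect.signature(fn).parameters)


class Layer(object):

  def call(self, inputs, mask=None, training=None):
    return inputs


# For a bound method, inspect.signature drops `self`, giving
# ('inputs', 'mask', 'training'); `'mask' in self._call_fn_args` then
# becomes a cheap membership test on the per-instance cached tuple.
print(fn_args(Layer().call))
```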
@@ -50,6 +50,7 @@ from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
+from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization
@@ -513,6 +514,7 @@ class Network(base_layer.Layer):
     return weights
 
   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return self._is_graph_network and super(Network, self)._should_compute_mask
 
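One subtlety in the Network hunk above: both `Layer._should_compute_mask` and this override are decorated, and each decorated function closes over its own WeakKeyDictionary, so the `super(...)` property access hits the base class's separate cache rather than colliding with the subclass's. A standalone sketch of that pattern (decorator body copied from this commit; class names are hypothetical):

```
import weakref


def cached_per_instance(f):
  # Same pattern as the decorator added in this commit: each decorated
  # function closes over its own WeakKeyDictionary.
  cache = weakref.WeakKeyDictionary()

  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output
  return wrapped


class Base(object):

  @property
  @cached_per_instance
  def flag(self):
    print("computing Base.flag")
    return True


class Derived(Base):

  @property
  @cached_per_instance
  def flag(self):
    print("computing Derived.flag")
    # super() property access runs (and separately caches) the base getter.
    return super(Derived, self).flag and True


d = Derived()
_ = d.flag  # Computes and caches both getters; prints twice.
_ = d.flag  # Fully cached; prints nothing.
```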
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import weakref
+
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import dtypes
@@ -241,6 +244,100 @@ class TrackableAsset(base.Trackable):
     """Fetch the current asset path."""
     return self._path
 
 
+def cached_per_instance(f):
+  """Lightweight decorator for caching lazily constructed properties.
+
+  When to use:
+  This decorator provides simple caching with minimal overhead. It is designed
+  for properties which are expensive to compute and static over the life of a
+  class instance, and provides no mechanism for cache invalidation. Thus it is
+  best suited for lazily exposing derived properties of other static data.
+
+  For classes with custom getattr / setattr behavior (such as trackable
+  objects), storing cache results as object attributes is not performant.
+  Instead, a specialized cache can significantly reduce property lookup
+  overhead while still allowing the decorated property to be lazily computed.
+  Consider the following class:
+
+  ```
+  class MyClass(object):
+    def __setattr__(self, key, value):
+      # Some expensive class specific code
+      # ...
+      # ...
+
+      super(MyClass, self).__setattr__(key, value)
+
+    @property
+    def thing(self):
+      # `thing` is expensive to compute (and may not even be requested), so we
+      # want to lazily compute it and then cache it.
+      output = getattr(self, '_thing', None)
+      if output is None:
+        self._thing = output = compute_thing(self)
+      return output
+  ```
+
+  It's also worth noting that ANY override of __setattr__, even something as
+  simple as:
+  ```
+  def __setattr__(self, key, value):
+    super(MyClass, self).__setattr__(key, value)
+  ```
+
+  slows down attribute assignment by nearly 10x.
+
+  By contrast, replacing the definition of `thing` with the following sidesteps
+  the expensive __setattr__ altogether:
+
+  ```
+  @property
+  @tracking.cached_per_instance
+  def thing(self):
+    # `thing` is expensive to compute (and may not even be requested), so we
+    # want to lazily compute it and then cache it.
+    return compute_thing(self)
+  ```
+
+  Performance:
+  The overhead for this decorator is ~0.4 us / call. A much lower overhead
+  implementation (~0.085 us / call) can be achieved by using a custom dict type:
+
+  ```
+  def dict_based_cache(f):
+    class Cache(dict):
+      __slots__ = ()
+      def __missing__(self, key):
+        self[key] = output = f(key)
+        return output
+
+    return property(Cache().__getitem__)
+  ```
+
+  However, that implementation holds class instances as keys, and as a result
+  blocks garbage collection. (Modifying it to use weakrefs as keys raises the
+  lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
+  implementation below turns out to be more prudent.
+
+  Args:
+    f: The function to cache.
+
+  Returns:
+    f decorated with simple caching behavior.
+  """
+
+  cache = weakref.WeakKeyDictionary()
+
+  @functools.wraps(f)
+  def wrapped(item):
+    output = cache.get(item)
+    if output is None:
+      cache[item] = output = f(item)
+    return output
+  return wrapped
+
+
 ops.register_tensor_conversion_function(
     TrackableAsset,
     lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
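A quick usage check of the decorator as committed (this assumes a TensorFlow build that contains this change; `Holder` is a hypothetical class):

```
import gc

from tensorflow.python.training.tracking import tracking


class Holder(tracking.AutoTrackable):

  @property
  @tracking.cached_per_instance
  def expensive(self):
    print("computing")
    return sum(range(10000))


h = Holder()
_ = h.expensive  # Prints "computing" once.
_ = h.expensive  # Cache hit: no print, and no attribute write on h.

# The WeakKeyDictionary cache holds no strong reference, so dropping the
# last reference frees both the instance and its cache entry.
del h
gc.collect()
```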
@@ -16,9 +16,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import contextlib
+import multiprocessing.dummy
 import os
+import pickle
+import time
+import timeit
 
-import numpy
+import numpy as np
 import six
 
 from tensorflow.python.framework import test_util
@@ -32,6 +38,23 @@ from tensorflow.python.training.tracking import util
 from tensorflow.python.util import nest
 
 
+_PICKLEABLE_CALL_COUNT = collections.Counter()
+
+
+class MyPickleableObject(tracking.AutoTrackable):
+  """Needed for InterfaceTests.test_property_cache_serialization.
+
+  This class must be at the top level. This is a constraint of pickle,
+  unrelated to `cached_per_instance`.
+  """
+
+  @property
+  @tracking.cached_per_instance
+  def my_id(self):
+    _PICKLEABLE_CALL_COUNT[self] += 1
+    return id(self)
+
+
 class InterfaceTests(test.TestCase):
 
   def testMultipleAssignment(self):
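The serialization test in the next hunk passes precisely because the cache lives in the decorator's closure rather than on the instance, so pickle never sees it. A self-contained sketch of the same invariant (decorator re-implemented inline; `Pickleable` is a hypothetical top-level class, as pickle requires):

```
import pickle
import weakref


def cached_per_instance(f):
  # Same closure-held WeakKeyDictionary as in the commit
  # (functools.wraps omitted for brevity).
  cache = weakref.WeakKeyDictionary()

  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output
  wrapped.cache = cache  # Exposed only so this sketch can inspect it.
  return wrapped


class Pickleable(object):

  @property
  @cached_per_instance
  def my_id(self):
    return id(self)


p = Pickleable()
_ = p.my_id
# The cached value lives in the closure, not on the instance, so the
# instance __dict__ stays empty and pickle serializes no cache state.
assert p.__dict__ == {}
assert p in Pickleable.my_id.fget.cache
q = pickle.loads(pickle.dumps(p))
assert q.my_id == id(q)
```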
@@ -199,15 +222,129 @@
   @test_util.run_in_graph_and_eager_modes
   def testAssertions(self):
     a = tracking.AutoTrackable()
-    a.l = {"k": [numpy.zeros([2, 2])]}
-    self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}),
+    a.l = {"k": [np.zeros([2, 2])]}
+    self.assertAllEqual(nest.flatten({"k": [np.zeros([2, 2])]}),
                         nest.flatten(a.l))
-    self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l)
-    nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]})
+    self.assertAllClose({"k": [np.zeros([2, 2])]}, a.l)
+    nest.map_structure(self.assertAllClose, a.l, {"k": [np.zeros([2, 2])]})
     a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
-    self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
+    self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
                         self.evaluate(a.tensors))
 
+  def test_property_cache(self):
+    test_counter = collections.Counter()
+
+    class MyObject(tracking.AutoTrackable):
+
+      def __init__(self):
+        super(MyObject, self).__init__()
+        self._frozen = True
+
+      def __setattr__(self, key, value):
+        """Enforce that the cache does not set attributes on MyObject."""
+        if getattr(self, "_frozen", False):
+          raise ValueError("Cannot mutate when frozen.")
+        return super(MyObject, self).__setattr__(key, value)
+
+      @property
+      @tracking.cached_per_instance
+      def test_property(self):
+        test_counter[id(self)] += 1
+        return id(self)
+
+    first_object = MyObject()
+    second_object = MyObject()
+
+    # Make sure the objects return the correct values
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Make sure the cache does not share across objects
+    self.assertNotEqual(first_object.test_property,
+                        second_object.test_property)
+
+    # Check again. (Now the values should be cached.)
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Count the function calls to make sure the cache is actually being used.
+    self.assertAllEqual(tuple(test_counter.values()), (1, 1))
+
+  def test_property_cache_threaded(self):
+    call_count = collections.Counter()
+
+    class MyObject(tracking.AutoTrackable):
+
+      @property
+      @tracking.cached_per_instance
+      def test_property(self):
+        # Random sleeps to ensure that the execution thread changes
+        # mid-computation.
+        call_count["test_property"] += 1
+        time.sleep(np.random.random() + 1.)
+
+        # Use a RandomState which is seeded off the instance's id (the mod is
+        # because numpy limits the range of seeds) to ensure that an instance
+        # returns the same value in different threads, but different instances
+        # return different values.
+        return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
+
+      def get_test_property(self, _):
+        """Function provided to .map for the threading test."""
+        return self.test_property
+
+    # Test that multiple threads return the same value. This requires that
+    # the underlying function is repeatable, as cached_per_instance makes no
+    # attempt to prioritize the first call.
+    test_obj = MyObject()
+    with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
+      # Intentionally make a large pool (even when there are only a small
+      # number of cpus) to ensure that the runtime switches threads.
+      results = pool.map(test_obj.get_test_property, range(64))
+    self.assertEqual(len(set(results)), 1)
+
+    # Make sure we actually are testing threaded behavior.
+    self.assertGreater(call_count["test_property"], 1)
+
+    # Make sure new threads still cache hit.
+    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
+      start_time = timeit.default_timer()  # Don't time pool instantiation.
+      results = pool.map(test_obj.get_test_property, range(4))
+    total_time = timeit.default_timer() - start_time
+
+    # Note(taylorrobie): The reason that it is safe to time a unit test is
+    #                    that a cache hit will be << 1 second, and a cache
+    #                    miss is guaranteed to be >= 1 second. Empirically
+    #                    confirmed by 100,000 runs with no flakes.
+    self.assertLess(total_time, 0.95)
+
+  def test_property_cache_serialization(self):
+    # Reset call count. .keys() must be wrapped in a list, because otherwise
+    # we would mutate the iterator while iterating.
+    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
+      _PICKLEABLE_CALL_COUNT.pop(k)
+
+    first_instance = MyPickleableObject()
+    self.assertEqual(id(first_instance), first_instance.my_id)
+
+    # Test that we can pickle and un-pickle
+    second_instance = pickle.loads(pickle.dumps(first_instance))
+
+    self.assertEqual(id(second_instance), second_instance.my_id)
+    self.assertNotEqual(first_instance.my_id, second_instance.my_id)
+
+    # Make sure the de-serialized object uses the cache.
+    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
+
+    # Make sure the decorator cache is not being serialized with the object.
+    expected_size = len(pickle.dumps(second_instance))
+    for _ in range(5):
+      # Add some more entries to the cache.
+      _ = MyPickleableObject().my_id
+    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
+    size_check_instance = MyPickleableObject()
+    _ = size_check_instance.my_id
+    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
+
 
 class _DummyResource(tracking.TrackableResource):
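Finally, the docstring's faster dict-based variant is rejected because a plain dict keyed on instances keeps them alive forever. A short sketch making that concrete (the `dict_based_cache` body is copied from the docstring; `Leaky` is a hypothetical class):

```
import gc
import weakref


def dict_based_cache(f):
  # The ~0.085 us variant from the docstring: fast, but the dict holds
  # strong references to every instance that ever used the property.
  class Cache(dict):
    __slots__ = ()

    def __missing__(self, key):
      self[key] = output = f(key)
      return output

  return property(Cache().__getitem__)


class Leaky(object):

  @dict_based_cache
  def value(self):
    return 42


obj = Leaky()
_ = obj.value
ref = weakref.ref(obj)
del obj
gc.collect()
# The cache's strong reference keeps the instance alive, which is why the
# commit uses a WeakKeyDictionary instead.
print(ref() is not None)  # True
```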