Reduce Python overheads for lazily constructed Layer properties.

PiperOrigin-RevId: 248435402
Taylor Robie 2019-05-15 17:08:20 -07:00 committed by TensorFlower Gardener
parent 83668b0826
commit 42ac719705
4 changed files with 245 additions and 9 deletions

View File

@@ -2046,12 +2046,12 @@ class Layer(module.Module):
    return True

  @property
+  @tracking.cached_per_instance
  def _call_fn_args(self):
-    if getattr(self, '__call_fn_args', None) is None:
-      self.__call_fn_args = function_utils.fn_args(self.call)
-    return self.__call_fn_args
+    return function_utils.fn_args(self.call)

  @property
+  @tracking.cached_per_instance
  def _should_compute_mask(self):
    return ('mask' in self._call_fn_args or
            getattr(self, 'compute_mask', None) is not None)
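The old body of `_call_fn_args` memoized its result through an instance attribute, so the first access paid for Layer's customized `__setattr__`; the new decorator keeps the cached value out of the instance entirely. A rough standalone sketch of the two styles (class names, the stand-in `__setattr__`, and the timing harness are illustrative, not from this commit; the decorator is re-stated from the tracking.py hunk below):

```
import functools
import timeit
import weakref


def cached_per_instance(f):
  # Same caching scheme as the decorator added in tracking.py later in this commit.
  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output
  return wrapped


class AttributeCached(object):
  """Old style: lazily compute, then store the result as an instance attribute."""

  def __setattr__(self, key, value):
    # Stand-in for a tracking-aware __setattr__ that does real work per assignment.
    super(AttributeCached, self).__setattr__(key, value)

  @property
  def call_fn_args(self):
    cached = getattr(self, '_call_fn_args_cache', None)
    if cached is None:
      # This assignment goes through the overridden __setattr__ above.
      self._call_fn_args_cache = cached = ('self', 'inputs')
    return cached


class DecoratorCached(object):
  """New style: results live in a WeakKeyDictionary keyed by the instance."""

  @property
  @cached_per_instance
  def call_fn_args(self):
    return ('self', 'inputs')


# Time construction plus first access, which is where the __setattr__ cost lands.
print(timeit.timeit(lambda: AttributeCached().call_fn_args, number=100000))
print(timeit.timeit(lambda: DecoratorCached().call_fn_args, number=100000))
```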

View File

@@ -50,6 +50,7 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
@@ -513,6 +514,7 @@ class Network(base_layer.Layer):
    return weights

  @property
+  @tracking.cached_per_instance
  def _should_compute_mask(self):
    return self._is_graph_network and super(Network, self)._should_compute_mask
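Because `cached_per_instance` provides no invalidation, this override effectively assumes `_is_graph_network` and the layer's `call` signature are fixed by the time the property is first read. A minimal sketch of that behavior (hypothetical `Toggle` class; assumes a TensorFlow build that includes this change):

```
from tensorflow.python.training.tracking import tracking


class Toggle(tracking.AutoTrackable):

  def __init__(self):
    super(Toggle, self).__init__()
    self.enabled = False

  @property
  @tracking.cached_per_instance
  def should_compute(self):
    # Evaluated at most once per instance; later changes to `enabled`
    # are never observed through this property.
    return self.enabled


t = Toggle()
print(t.should_compute)  # False, and now cached for the life of `t`
t.enabled = True
print(t.should_compute)  # still False: there is no cache invalidation
```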

View File

@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import weakref
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
@@ -241,6 +244,100 @@ class TrackableAsset(base.Trackable):
    """Fetch the current asset path."""
    return self._path


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead while still allowing the decorated property to be lazily computed.
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...
      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:

  ```
  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output
    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  return wrapped


ops.register_tensor_conversion_function(
    TrackableAsset,
    lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
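The docstring's faster dict-based cache is rejected because it would hold instances as keys and block garbage collection; the WeakKeyDictionary used here lets cache entries die with their instances. A small illustrative check of that property (hypothetical `Holder` class; assumes a TensorFlow build that includes this change):

```
import gc
import weakref

from tensorflow.python.training.tracking import tracking


class Holder(tracking.AutoTrackable):

  @property
  @tracking.cached_per_instance
  def expensive(self):
    return len(repr(self))  # stand-in for a costly derived value


h = Holder()
_ = h.expensive          # populates the decorator's WeakKeyDictionary
ref = weakref.ref(h)

del h
gc.collect()
print(ref() is None)     # True: the cache did not keep the instance alive
```

With the dict-based cache sketched in the docstring, `ref()` would still return the instance, because the cache itself would hold a strong reference to it as a key.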

View File

@@ -16,9 +16,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
+import contextlib
+import multiprocessing.dummy
import os
import pickle
+import time
+import timeit

-import numpy
+import numpy as np
import six

from tensorflow.python.framework import test_util
@@ -32,6 +38,23 @@ from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest


_PICKLEABLE_CALL_COUNT = collections.Counter()


class MyPickleableObject(tracking.AutoTrackable):
  """Needed for InterfaceTests.test_property_cache_serialization.

  This class must be at the top level. This is a constraint of pickle,
  unrelated to `cached_per_instance`.
  """

  @property
  @tracking.cached_per_instance
  def my_id(self):
    _PICKLEABLE_CALL_COUNT[self] += 1
    return id(self)


class InterfaceTests(test.TestCase):

  def testMultipleAssignment(self):
@@ -199,15 +222,129 @@ class InterfaceTests(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testAssertions(self):
    a = tracking.AutoTrackable()
-    a.l = {"k": [numpy.zeros([2, 2])]}
-    self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}),
+    a.l = {"k": [np.zeros([2, 2])]}
+    self.assertAllEqual(nest.flatten({"k": [np.zeros([2, 2])]}),
                        nest.flatten(a.l))
-    self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l)
-    nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]})
+    self.assertAllClose({"k": [np.zeros([2, 2])]}, a.l)
+    nest.map_structure(self.assertAllClose, a.l, {"k": [np.zeros([2, 2])]})
    a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
-    self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
+    self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
                        self.evaluate(a.tensors))
  def test_property_cache(self):
    test_counter = collections.Counter()

    class MyObject(tracking.AutoTrackable):

      def __init__(self):
        super(MyObject, self).__init__()
        self._frozen = True

      def __setattr__(self, key, value):
        """Enforce that cache does not set attribute on MyObject."""
        if getattr(self, "_frozen", False):
          raise ValueError("Cannot mutate when frozen.")
        return super(MyObject, self).__setattr__(key, value)

      @property
      @tracking.cached_per_instance
      def test_property(self):
        test_counter[id(self)] += 1
        return id(self)

    first_object = MyObject()
    second_object = MyObject()

    # Make sure the objects return the correct values
    self.assertEqual(first_object.test_property, id(first_object))
    self.assertEqual(second_object.test_property, id(second_object))

    # Make sure the cache does not share across objects
    self.assertNotEqual(first_object.test_property, second_object.test_property)

    # Check again (Now the values should be cached.)
    self.assertEqual(first_object.test_property, id(first_object))
    self.assertEqual(second_object.test_property, id(second_object))

    # Count the function calls to make sure the cache is actually being used.
    self.assertAllEqual(tuple(test_counter.values()), (1, 1))

  def test_property_cache_threaded(self):
    call_count = collections.Counter()

    class MyObject(tracking.AutoTrackable):

      @property
      @tracking.cached_per_instance
      def test_property(self):
        # Random sleeps to ensure that the execution thread changes
        # mid-computation.
        call_count["test_property"] += 1
        time.sleep(np.random.random() + 1.)

        # Use a RandomState which is seeded off the instance's id (the mod is
        # because numpy limits the range of seeds) to ensure that an instance
        # returns the same value in different threads, but different instances
        # return different values.
        return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))

      def get_test_property(self, _):
        """Function provided to .map for threading test."""
        return self.test_property

    # Test that multiple threads return the same value. This requires that
    # the underlying function is repeatable, as cached_per_instance makes no
    # attempt to prioritize the first call.
    test_obj = MyObject()
    with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
      # Intentionally make a large pool (even when there are only a small number
      # of cpus) to ensure that the runtime switches threads.
      results = pool.map(test_obj.get_test_property, range(64))
    self.assertEqual(len(set(results)), 1)

    # Make sure we actually are testing threaded behavior.
    self.assertGreater(call_count["test_property"], 1)

    # Make sure new threads still cache hit.
    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
      start_time = timeit.default_timer()  # Don't time pool instantiation.
      results = pool.map(test_obj.get_test_property, range(4))
    total_time = timeit.default_timer() - start_time

    # Note(taylorrobie): The reason that it is safe to time a unit test is that
    #                    a cache hit will be << 1 second, and a cache miss is
    #                    guaranteed to be >= 1 second. Empirically confirmed by
    #                    100,000 runs with no flakes.
    self.assertLess(total_time, 0.95)
  def test_property_cache_serialization(self):
    # Reset call count. .keys() must be wrapped in a list, because otherwise we
    # would mutate the iterator while iterating.
    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
      _PICKLEABLE_CALL_COUNT.pop(k)

    first_instance = MyPickleableObject()
    self.assertEqual(id(first_instance), first_instance.my_id)

    # Test that we can pickle and un-pickle
    second_instance = pickle.loads(pickle.dumps(first_instance))
    self.assertEqual(id(second_instance), second_instance.my_id)
    self.assertNotEqual(first_instance.my_id, second_instance.my_id)

    # Make sure de-serialized object uses the cache.
    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)

    # Make sure the decorator cache is not being serialized with the object.
    expected_size = len(pickle.dumps(second_instance))
    for _ in range(5):
      # Add some more entries to the cache.
      _ = MyPickleableObject().my_id
    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
    size_check_instance = MyPickleableObject()
    _ = size_check_instance.my_id
    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))


class _DummyResource(tracking.TrackableResource):