Reduce Python overheads for lazily constructed Layer properties.

PiperOrigin-RevId: 248435402
Taylor Robie 2019-05-15 17:08:20 -07:00 committed by TensorFlower Gardener
parent 83668b0826
commit 42ac719705
4 changed files with 245 additions and 9 deletions

tensorflow/python/keras/engine/base_layer.py

@@ -2046,12 +2046,12 @@ class Layer(module.Module):
     return True

   @property
+  @tracking.cached_per_instance
   def _call_fn_args(self):
-    if getattr(self, '__call_fn_args', None) is None:
-      self.__call_fn_args = function_utils.fn_args(self.call)
-    return self.__call_fn_args
+    return function_utils.fn_args(self.call)

   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return ('mask' in self._call_fn_args or
             getattr(self, 'compute_mask', None) is not None)
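For context on why the memoize-on-self pattern above was dropped: on trackable objects every attribute write goes through a custom `__setattr__`, so even the one-time caching write is expensive, while `tracking.cached_per_instance` (added to tracking.py later in this commit) keeps results in a side table instead. A rough, self-contained sketch of the comparison, where `Tracked` and `expensive` are hypothetical stand-ins and timings will vary by machine:

```
import timeit
import weakref


def cached_per_instance(f):
  # Same scheme as the decorator this commit adds to tracking.py.
  cache = weakref.WeakKeyDictionary()
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output
  return wrapped


def expensive(obj):
  # Hypothetical stand-in for a costly derived property such as
  # function_utils.fn_args(self.call).
  return sum(range(100))


class Tracked(object):
  """Hypothetical stand-in for an object with a costly __setattr__."""

  def __setattr__(self, key, value):
    # Trackable objects do bookkeeping on every attribute write.
    super(Tracked, self).__setattr__(key, value)

  @property
  def lazy_attr(self):  # The old memoize-on-self pattern.
    if getattr(self, '_lazy', None) is None:
      self._lazy = expensive(self)
    return self._lazy

  @property
  @cached_per_instance
  def cached_attr(self):  # The new side-table pattern.
    return expensive(self)


obj = Tracked()
# On this toy class the two are close; on trackable objects, whose
# getattr/setattr hooks are far heavier, the side-table cache wins.
print(timeit.timeit(lambda: obj.lazy_attr, number=100000))
print(timeit.timeit(lambda: obj.cached_attr, number=100000))
```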

tensorflow/python/keras/engine/network.py

@@ -50,6 +50,7 @@ from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
+from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization

@@ -513,6 +514,7 @@ class Network(base_layer.Layer):
     return weights

   @property
+  @tracking.cached_per_instance
   def _should_compute_mask(self):
     return self._is_graph_network and super(Network, self)._should_compute_mask

tensorflow/python/training/tracking/tracking.py

@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
+import weakref
+
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import dtypes
@@ -241,6 +244,100 @@ class TrackableAsset(base.Trackable):
     """Fetch the current asset path."""
     return self._path


+def cached_per_instance(f):
+  """Lightweight decorator for caching lazily constructed properties.
+
+  When to use:
+  This decorator provides simple caching with minimal overhead. It is designed
+  for properties which are expensive to compute and static over the life of a
+  class instance, and it provides no mechanism for cache invalidation. Thus it
+  is best suited for lazily exposing derived properties of other static data.
+
+  For classes with custom getattr / setattr behavior (such as trackable
+  objects), storing cache results as object attributes is not performant.
+  Instead, a specialized cache can significantly reduce property lookup
+  overhead, while still allowing the decorated property to be lazily computed.
+
+  Consider the following class:
+
+  ```
+  class MyClass(object):
+    def __setattr__(self, key, value):
+      # Some expensive class specific code
+      # ...
+      # ...
+      super(MyClass, self).__setattr__(key, value)
+
+    @property
+    def thing(self):
+      # `thing` is expensive to compute (and may not even be requested), so we
+      # want to lazily compute it and then cache it.
+      output = getattr(self, '_thing', None)
+      if output is None:
+        self._thing = output = compute_thing(self)
+      return output
+  ```
+
+  It is also worth noting that ANY override of __setattr__, even one as
+  simple as:
+
+  ```
+  def __setattr__(self, key, value):
+    super(MyClass, self).__setattr__(key, value)
+  ```
+
+  slows down attribute assignment by nearly 10x.
+
+  By contrast, replacing the definition of `thing` with the following sidesteps
+  the expensive __setattr__ altogether:
+
+  ```
+  @property
+  @tracking.cached_per_instance
+  def thing(self):
+    # `thing` is expensive to compute (and may not even be requested), so we
+    # want to lazily compute it and then cache it.
+    return compute_thing(self)
+  ```
+
+  Performance:
+  The overhead of this decorator is ~0.4 us / call. A much lower overhead
+  implementation (~0.085 us / call) can be achieved by using a custom dict
+  type:
+
+  ```
+  def dict_based_cache(f):
+    class Cache(dict):
+      __slots__ = ()
+      def __missing__(self, key):
+        self[key] = output = f(key)
+        return output
+    return property(Cache().__getitem__)
+  ```
+
+  However, that implementation holds class instances as keys and therefore
+  blocks garbage collection. (Modifying it to use weakrefs as keys raises the
+  lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
+  implementation below turns out to be more prudent.
+
+  Args:
+    f: The function to cache.
+
+  Returns:
+    f decorated with simple caching behavior.
+  """
+  cache = weakref.WeakKeyDictionary()
+
+  @functools.wraps(f)
+  def wrapped(item):
+    output = cache.get(item)
+    if output is None:
+      cache[item] = output = f(item)
+    return output
+
+  return wrapped
+
+
 ops.register_tensor_conversion_function(
     TrackableAsset,
     lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
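A quick sketch of the garbage-collection behavior that motivated the WeakKeyDictionary. The `.cache` attribute is exposed here purely so the demo can inspect it; the decorator above does not expose its cache. Assumes CPython, where refcounting reclaims the instance as soon as `del` drops the last reference:

```
import weakref


def cached_per_instance(f):
  cache = weakref.WeakKeyDictionary()
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output
  wrapped.cache = cache  # Demo-only: expose the cache for inspection.
  return wrapped


class Thing(object):

  @property
  @cached_per_instance
  def value(self):
    return object()  # Stand-in for an expensive computation.


t = Thing()
_ = t.value
print(len(Thing.value.fget.cache))  # 1: entry held while t is alive.
del t
print(len(Thing.value.fget.cache))  # 0: entry vanished with the instance.
```

A plain dict cache would keep `t` alive forever here; the weak keys are what let cached instances die normally.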

tensorflow/python/training/tracking/tracking_test.py

@@ -16,9 +16,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+import contextlib
+import multiprocessing.dummy
 import os
+import pickle
+import time
+import timeit

-import numpy
+import numpy as np
 import six

 from tensorflow.python.framework import test_util
@@ -32,6 +38,23 @@ from tensorflow.python.training.tracking import util
 from tensorflow.python.util import nest


+_PICKLEABLE_CALL_COUNT = collections.Counter()
+
+
+class MyPickleableObject(tracking.AutoTrackable):
+  """Needed for InterfaceTests.test_property_cache_serialization.
+
+  This class must be at the top level. This is a constraint of pickle,
+  unrelated to `cached_per_instance`.
+  """
+
+  @property
+  @tracking.cached_per_instance
+  def my_id(self):
+    _PICKLEABLE_CALL_COUNT[self] += 1
+    return id(self)
+
+
 class InterfaceTests(test.TestCase):

   def testMultipleAssignment(self):
@@ -199,15 +222,129 @@ class InterfaceTests(test.TestCase):

   @test_util.run_in_graph_and_eager_modes
   def testAssertions(self):
     a = tracking.AutoTrackable()
-    a.l = {"k": [numpy.zeros([2, 2])]}
-    self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}),
+    a.l = {"k": [np.zeros([2, 2])]}
+    self.assertAllEqual(nest.flatten({"k": [np.zeros([2, 2])]}),
                         nest.flatten(a.l))
-    self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l)
-    nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]})
+    self.assertAllClose({"k": [np.zeros([2, 2])]}, a.l)
+    nest.map_structure(self.assertAllClose, a.l, {"k": [np.zeros([2, 2])]})
     a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
-    self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
+    self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
                         self.evaluate(a.tensors))
+
+  def test_property_cache(self):
+    test_counter = collections.Counter()
+
+    class MyObject(tracking.AutoTrackable):
+
+      def __init__(self):
+        super(MyObject, self).__init__()
+        self._frozen = True
+
+      def __setattr__(self, key, value):
+        """Enforce that cache does not set attribute on MyObject."""
+        if getattr(self, "_frozen", False):
+          raise ValueError("Cannot mutate when frozen.")
+        return super(MyObject, self).__setattr__(key, value)
+
+      @property
+      @tracking.cached_per_instance
+      def test_property(self):
+        test_counter[id(self)] += 1
+        return id(self)
+
+    first_object = MyObject()
+    second_object = MyObject()
+
+    # Make sure the objects return the correct values
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Make sure the cache does not share across objects
+    self.assertNotEqual(first_object.test_property,
+                        second_object.test_property)
+
+    # Check again (Now the values should be cached.)
+    self.assertEqual(first_object.test_property, id(first_object))
+    self.assertEqual(second_object.test_property, id(second_object))
+
+    # Count the function calls to make sure the cache is actually being used.
+    self.assertAllEqual(tuple(test_counter.values()), (1, 1))
+
+  def test_property_cache_threaded(self):
+    call_count = collections.Counter()

+    class MyObject(tracking.AutoTrackable):
+
+      @property
+      @tracking.cached_per_instance
+      def test_property(self):
+        # Random sleeps to ensure that the execution thread changes
+        # mid-computation.
+        call_count["test_property"] += 1
+        time.sleep(np.random.random() + 1.)
+
+        # Use a RandomState which is seeded off the instance's id (the mod is
+        # because numpy limits the range of seeds) to ensure that an instance
+        # returns the same value in different threads, but different instances
+        # return different values.
+        return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
+
+      def get_test_property(self, _):
+        """Function provided to .map for threading test."""
+        return self.test_property
+
+    # Test that multiple threads return the same value. This requires that
+    # the underlying function is repeatable, as cached_property makes no
+    # attempt to prioritize the first call.
+    test_obj = MyObject()
+    with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
+      # Intentionally make a large pool (even when there are only a small
+      # number of cpus) to ensure that the runtime switches threads.
+      results = pool.map(test_obj.get_test_property, range(64))
+    self.assertEqual(len(set(results)), 1)
+
+    # Make sure we actually are testing threaded behavior.
+    self.assertGreater(call_count["test_property"], 1)
+
+    # Make sure new threads still cache hit.
+    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
+      start_time = timeit.default_timer()  # Don't time pool instantiation.
+      results = pool.map(test_obj.get_test_property, range(4))
+    total_time = timeit.default_timer() - start_time
+
+    # Note(taylorrobie): The reason that it is safe to time a unit test is
+    #                    that a cache hit will be << 1 second, and a cache
+    #                    miss is guaranteed to be >= 1 second. Empirically
+    #                    confirmed by 100,000 runs with no flakes.
+    self.assertLess(total_time, 0.95)
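As the comments above note, `cached_per_instance` takes no lock: two threads that miss the cache at the same time may both run the decorated function, which is why this test requires the function to be repeatable. Where at-most-once execution mattered, a locked variant along these lines could be used (a hedged sketch, not part of this commit):

```
import threading
import weakref


def locked_cached_per_instance(f):
  cache = weakref.WeakKeyDictionary()
  lock = threading.Lock()
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      with lock:
        output = cache.get(item)  # Re-check: another thread may have filled it.
        if output is None:
          cache[item] = output = f(item)
    return output
  return wrapped
```

The trade-off is the lock acquisition on every cache miss, which the unlocked version avoids in exchange for occasionally recomputing.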
+
+  def test_property_cache_serialization(self):
+    # Reset call count. .keys() must be wrapped in a list, because otherwise
+    # we would mutate the iterator while iterating.
+    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
+      _PICKLEABLE_CALL_COUNT.pop(k)
+
+    first_instance = MyPickleableObject()
+    self.assertEqual(id(first_instance), first_instance.my_id)
+
+    # Test that we can pickle and un-pickle
+    second_instance = pickle.loads(pickle.dumps(first_instance))
+    self.assertEqual(id(second_instance), second_instance.my_id)
+    self.assertNotEqual(first_instance.my_id, second_instance.my_id)
+
+    # Make sure de-serialized object uses the cache.
+    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
+
+    # Make sure the decorator cache is not being serialized with the object.
+    expected_size = len(pickle.dumps(second_instance))
+    for _ in range(5):
+      # Add some more entries to the cache.
+      _ = MyPickleableObject().my_id
+    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
+    size_check_instance = MyPickleableObject()
+    _ = size_check_instance.my_id
+    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
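The serialization test passes because the cache lives in the decorator's closure rather than in the instance's `__dict__`, so pickle never sees it. A minimal sketch of that invariant, reusing `MyPickleableObject` from above:

```
obj = MyPickleableObject()
_ = obj.my_id                       # Populates the closure-held cache.
assert "my_id" not in obj.__dict__  # Nothing was written onto the instance.
```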
 class _DummyResource(tracking.TrackableResource):