Move cached_per_instance to keras utils since its only used in Keras.
PiperOrigin-RevId: 326677839 Change-Id: I04fb71d17241b65fc1d5ae8f69e4d40770357bf7
This commit is contained in:
parent
f4bae1839c
commit
57e69437b4
@ -2926,14 +2926,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._call_accepts_kwargs)
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_full_argspec(self):
|
||||
# Argspec inspection is expensive and the call spec is used often, so it
|
||||
# makes sense to cache the result.
|
||||
return tf_inspect.getfullargspec(self.call)
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_fn_args(self):
|
||||
all_args = self._call_full_argspec.args
|
||||
# Scrub `self` that appears if a decorator was applied.
|
||||
@ -2942,7 +2942,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
return all_args
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_fn_arg_defaults(self):
|
||||
call_fn_args = self._call_fn_args
|
||||
call_fn_defaults = self._call_full_argspec.defaults or []
|
||||
@ -2955,7 +2955,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
return defaults
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_fn_arg_positions(self):
|
||||
call_fn_arg_positions = dict()
|
||||
for pos, arg in enumerate(self._call_fn_args):
|
||||
@ -2963,7 +2963,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
return call_fn_arg_positions
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_accepts_kwargs(self):
|
||||
return self._call_full_argspec.varkw is not None
|
||||
|
||||
|
@ -2342,14 +2342,14 @@ class Layer(base_layer.Layer):
|
||||
self._call_accepts_kwargs)
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_full_argspec(self):
|
||||
# Argspec inspection is expensive and the call spec is used often, so it
|
||||
# makes sense to cache the result.
|
||||
return tf_inspect.getfullargspec(self.call)
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_fn_args(self):
|
||||
all_args = self._call_full_argspec.args
|
||||
# Scrub `self` that appears if a decorator was applied.
|
||||
@ -2358,7 +2358,7 @@ class Layer(base_layer.Layer):
|
||||
return all_args
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_fn_arg_positions(self):
|
||||
call_fn_arg_positions = dict()
|
||||
for pos, arg in enumerate(self._call_fn_args):
|
||||
@ -2366,12 +2366,12 @@ class Layer(base_layer.Layer):
|
||||
return call_fn_arg_positions
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _call_accepts_kwargs(self):
|
||||
return self._call_full_argspec.varkw is not None
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _should_compute_mask(self):
|
||||
return ('mask' in self._call_fn_args or
|
||||
getattr(self, 'compute_mask', None) is not None)
|
||||
|
@ -49,6 +49,7 @@ py_library(
|
||||
"//tensorflow/python/keras:backend_config",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras/engine:base_layer_utils",
|
||||
"//tensorflow/python/keras/utils:layer_utils",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
],
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
|
||||
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -48,7 +49,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -1207,12 +1207,12 @@ class OptimizerV2(trackable.Trackable):
|
||||
return x.value()
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _dense_apply_args(self):
|
||||
return tf_inspect.getfullargspec(self._resource_apply_dense).args
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
@layer_utils.cached_per_instance
|
||||
def _sparse_apply_args(self):
|
||||
return tf_inspect.getfullargspec(self._resource_apply_sparse).args
|
||||
|
||||
|
@ -301,6 +301,19 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "layer_utils_test",
|
||||
size = "small",
|
||||
srcs = ["layer_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":layer_utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "np_utils_test",
|
||||
size = "small",
|
||||
|
@ -19,6 +19,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
@ -404,3 +407,99 @@ def is_builtin_layer(layer):
|
||||
# of the base layer class.
|
||||
return (layer._keras_api_names != ('keras.layers.Layer',) and
|
||||
layer._keras_api_names_v1 != ('keras.layers.Layer',))
|
||||
|
||||
|
||||
def cached_per_instance(f):
|
||||
"""Lightweight decorator for caching lazily constructed properties.
|
||||
|
||||
When to use:
|
||||
This decorator provides simple caching with minimal overhead. It is designed
|
||||
for properties which are expensive to compute and static over the life of a
|
||||
class instance, and provides no mechanism for cache invalidation. Thus it is
|
||||
best suited for lazily exposing derived properties of other static data.
|
||||
|
||||
For classes with custom getattr / setattr behavior (such as trackable
|
||||
objects), storing cache results as object attributes is not performant.
|
||||
Instead, a specialized cache can significantly reduce property lookup
|
||||
overhead. (While still allowing the decorated property to be lazily computed.)
|
||||
Consider the following class:
|
||||
|
||||
```
|
||||
class MyClass(object):
|
||||
def __setattr__(self, key, value):
|
||||
# Some expensive class specific code
|
||||
# ...
|
||||
# ...
|
||||
|
||||
super(MyClass, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
def thing(self):
|
||||
# `thing` is expensive to compute (and may not even be requested), so we
|
||||
# want to lazily compute it and then cache it.
|
||||
output = getattr(self, '_thing', None)
|
||||
if output is None:
|
||||
self._thing = output = compute_thing(self)
|
||||
return output
|
||||
```
|
||||
|
||||
It's also worth noting that ANY overriding of __setattr__, even something as
|
||||
simple as:
|
||||
```
|
||||
def __setattr__(self, key, value):
|
||||
super(MyClass, self).__setattr__(key, value)
|
||||
```
|
||||
|
||||
Slows down attribute assignment by nearly 10x.
|
||||
|
||||
By contrast, replacing the definition of `thing` with the following sidesteps
|
||||
the expensive __setattr__ altogether:
|
||||
|
||||
'''
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def thing(self):
|
||||
# `thing` is expensive to compute (and may not even be requested), so we
|
||||
# want to lazily compute it and then cache it.
|
||||
return compute_thing(self)
|
||||
'''
|
||||
|
||||
Performance:
|
||||
The overhead for this decorator is ~0.4 us / call. A much lower overhead
|
||||
implementation (~0.085 us / call) can be achieved by using a custom dict type:
|
||||
|
||||
```
|
||||
def dict_based_cache(f):
|
||||
class Cache(dict):
|
||||
__slots__ = ()
|
||||
def __missing__(self, key):
|
||||
self[key] = output = f(key)
|
||||
return output
|
||||
|
||||
return property(Cache().__getitem__)
|
||||
```
|
||||
|
||||
However, that implementation holds class instances as keys, and as a result
|
||||
blocks garbage collection. (And modifying it to use weakref's as keys raises
|
||||
the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
|
||||
implementation below turns out to be more prudent.
|
||||
|
||||
Args:
|
||||
f: The function to cache.
|
||||
|
||||
Returns:
|
||||
f decorated with simple caching behavior.
|
||||
"""
|
||||
|
||||
cache = weakref.WeakKeyDictionary()
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(item):
|
||||
output = cache.get(item)
|
||||
if output is None:
|
||||
cache[item] = output = f(item)
|
||||
return output
|
||||
|
||||
wrapped.cache = cache
|
||||
return wrapped
|
||||
|
||||
|
170
tensorflow/python/keras/utils/layer_utils_test.py
Normal file
170
tensorflow/python/keras/utils/layer_utils_test.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for layer_utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import multiprocessing.dummy
|
||||
import pickle
|
||||
import time
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
||||
|
||||
_PICKLEABLE_CALL_COUNT = collections.Counter()
|
||||
|
||||
|
||||
class MyPickleableObject(tracking.AutoTrackable):
|
||||
"""Needed for InterfaceTests.test_property_cache_serialization.
|
||||
|
||||
This class must be at the top level. This is a constraint of pickle,
|
||||
unrelated to `cached_per_instance`.
|
||||
"""
|
||||
|
||||
@property
|
||||
@layer_utils.cached_per_instance
|
||||
def my_id(self):
|
||||
_PICKLEABLE_CALL_COUNT[self] += 1
|
||||
return id(self)
|
||||
|
||||
|
||||
class LayerUtilsTest(test.TestCase):
|
||||
|
||||
def test_property_cache(self):
|
||||
test_counter = collections.Counter()
|
||||
|
||||
class MyObject(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
super(MyObject, self).__init__()
|
||||
self._frozen = True
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"""Enforce that cache does not set attribute on MyObject."""
|
||||
if getattr(self, "_frozen", False):
|
||||
raise ValueError("Cannot mutate when frozen.")
|
||||
return super(MyObject, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
@layer_utils.cached_per_instance
|
||||
def test_property(self):
|
||||
test_counter[id(self)] += 1
|
||||
return id(self)
|
||||
|
||||
first_object = MyObject()
|
||||
second_object = MyObject()
|
||||
|
||||
# Make sure the objects return the correct values
|
||||
self.assertEqual(first_object.test_property, id(first_object))
|
||||
self.assertEqual(second_object.test_property, id(second_object))
|
||||
|
||||
# Make sure the cache does not share across objects
|
||||
self.assertNotEqual(first_object.test_property, second_object.test_property)
|
||||
|
||||
# Check again (Now the values should be cached.)
|
||||
self.assertEqual(first_object.test_property, id(first_object))
|
||||
self.assertEqual(second_object.test_property, id(second_object))
|
||||
|
||||
# Count the function calls to make sure the cache is actually being used.
|
||||
self.assertAllEqual(tuple(test_counter.values()), (1, 1))
|
||||
|
||||
def test_property_cache_threaded(self):
|
||||
call_count = collections.Counter()
|
||||
|
||||
class MyObject(tracking.AutoTrackable):
|
||||
|
||||
@property
|
||||
@layer_utils.cached_per_instance
|
||||
def test_property(self):
|
||||
# Random sleeps to ensure that the execution thread changes
|
||||
# mid-computation.
|
||||
call_count["test_property"] += 1
|
||||
time.sleep(np.random.random() + 1.)
|
||||
|
||||
# Use a RandomState which is seeded off the instance's id (the mod is
|
||||
# because numpy limits the range of seeds) to ensure that an instance
|
||||
# returns the same value in different threads, but different instances
|
||||
# return different values.
|
||||
return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
|
||||
|
||||
def get_test_property(self, _):
|
||||
"""Function provided to .map for threading test."""
|
||||
return self.test_property
|
||||
|
||||
# Test that multiple threads return the same value. This requires that
|
||||
# the underlying function is repeatable, as cached_property makes no attempt
|
||||
# to prioritize the first call.
|
||||
test_obj = MyObject()
|
||||
with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
|
||||
# Intentionally make a large pool (even when there are only a small number
|
||||
# of cpus) to ensure that the runtime switches threads.
|
||||
results = pool.map(test_obj.get_test_property, range(64))
|
||||
self.assertEqual(len(set(results)), 1)
|
||||
|
||||
# Make sure we actually are testing threaded behavior.
|
||||
self.assertGreater(call_count["test_property"], 1)
|
||||
|
||||
# Make sure new threads still cache hit.
|
||||
with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
|
||||
start_time = timeit.default_timer() # Don't time pool instantiation.
|
||||
results = pool.map(test_obj.get_test_property, range(4))
|
||||
total_time = timeit.default_timer() - start_time
|
||||
|
||||
# Note(taylorrobie): The reason that it is safe to time a unit test is that
|
||||
# a cache hit will be << 1 second, and a cache miss is
|
||||
# guaranteed to be >= 1 second. Empirically confirmed by
|
||||
# 100,000 runs with no flakes.
|
||||
self.assertLess(total_time, 0.95)
|
||||
|
||||
def test_property_cache_serialization(self):
|
||||
# Reset call count. .keys() must be wrapped in a list, because otherwise we
|
||||
# would mutate the iterator while iterating.
|
||||
for k in list(_PICKLEABLE_CALL_COUNT.keys()):
|
||||
_PICKLEABLE_CALL_COUNT.pop(k)
|
||||
|
||||
first_instance = MyPickleableObject()
|
||||
self.assertEqual(id(first_instance), first_instance.my_id)
|
||||
|
||||
# Test that we can pickle and un-pickle
|
||||
second_instance = pickle.loads(pickle.dumps(first_instance))
|
||||
|
||||
self.assertEqual(id(second_instance), second_instance.my_id)
|
||||
self.assertNotEqual(first_instance.my_id, second_instance.my_id)
|
||||
|
||||
# Make sure de-serialized object uses the cache.
|
||||
self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
|
||||
|
||||
# Make sure the decorator cache is not being serialized with the object.
|
||||
expected_size = len(pickle.dumps(second_instance))
|
||||
for _ in range(5):
|
||||
# Add some more entries to the cache.
|
||||
_ = MyPickleableObject().my_id
|
||||
self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
|
||||
size_check_instance = MyPickleableObject()
|
||||
_ = size_check_instance.my_id
|
||||
self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -18,8 +18,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import weakref
|
||||
|
||||
from absl import logging
|
||||
|
||||
@ -357,100 +355,5 @@ class Asset(base.Trackable):
|
||||
return self._path
|
||||
|
||||
|
||||
def cached_per_instance(f):
|
||||
"""Lightweight decorator for caching lazily constructed properties.
|
||||
|
||||
When to use:
|
||||
This decorator provides simple caching with minimal overhead. It is designed
|
||||
for properties which are expensive to compute and static over the life of a
|
||||
class instance, and provides no mechanism for cache invalidation. Thus it is
|
||||
best suited for lazily exposing derived properties of other static data.
|
||||
|
||||
For classes with custom getattr / setattr behavior (such as trackable
|
||||
objects), storing cache results as object attributes is not performant.
|
||||
Instead, a specialized cache can significantly reduce property lookup
|
||||
overhead. (While still allowing the decorated property to be lazily computed.)
|
||||
Consider the following class:
|
||||
|
||||
```
|
||||
class MyClass(object):
|
||||
def __setattr__(self, key, value):
|
||||
# Some expensive class specific code
|
||||
# ...
|
||||
# ...
|
||||
|
||||
super(MyClass, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
def thing(self):
|
||||
# `thing` is expensive to compute (and may not even be requested), so we
|
||||
# want to lazily compute it and then cache it.
|
||||
output = getattr(self, '_thing', None)
|
||||
if output is None:
|
||||
self._thing = output = compute_thing(self)
|
||||
return output
|
||||
```
|
||||
|
||||
It's also worth noting that ANY overriding of __setattr__, even something as
|
||||
simple as:
|
||||
```
|
||||
def __setattr__(self, key, value):
|
||||
super(MyClass, self).__setattr__(key, value)
|
||||
```
|
||||
|
||||
Slows down attribute assignment by nearly 10x.
|
||||
|
||||
By contrast, replacing the definition of `thing` with the following sidesteps
|
||||
the expensive __setattr__ altogether:
|
||||
|
||||
'''
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def thing(self):
|
||||
# `thing` is expensive to compute (and may not even be requested), so we
|
||||
# want to lazily compute it and then cache it.
|
||||
return compute_thing(self)
|
||||
'''
|
||||
|
||||
Performance:
|
||||
The overhead for this decorator is ~0.4 us / call. A much lower overhead
|
||||
implementation (~0.085 us / call) can be achieved by using a custom dict type:
|
||||
|
||||
```
|
||||
def dict_based_cache(f):
|
||||
class Cache(dict):
|
||||
__slots__ = ()
|
||||
def __missing__(self, key):
|
||||
self[key] = output = f(key)
|
||||
return output
|
||||
|
||||
return property(Cache().__getitem__)
|
||||
```
|
||||
|
||||
However, that implementation holds class instances as keys, and as a result
|
||||
blocks garbage collection. (And modifying it to use weakref's as keys raises
|
||||
the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
|
||||
implementation below turns out to be more prudent.
|
||||
|
||||
Args:
|
||||
f: The function to cache.
|
||||
|
||||
Returns:
|
||||
f decorated with simple caching behavior.
|
||||
"""
|
||||
|
||||
cache = weakref.WeakKeyDictionary()
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(item):
|
||||
output = cache.get(item)
|
||||
if output is None:
|
||||
cache[item] = output = f(item)
|
||||
return output
|
||||
|
||||
wrapped.cache = cache
|
||||
return wrapped
|
||||
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
Asset, lambda asset, **kw: ops.convert_to_tensor(asset.asset_path, **kw))
|
||||
|
@ -16,13 +16,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import multiprocessing.dummy
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -35,23 +29,6 @@ from tensorflow.python.training.tracking import util
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
_PICKLEABLE_CALL_COUNT = collections.Counter()
|
||||
|
||||
|
||||
class MyPickleableObject(tracking.AutoTrackable):
|
||||
"""Needed for InterfaceTests.test_property_cache_serialization.
|
||||
|
||||
This class must be at the top level. This is a constraint of pickle,
|
||||
unrelated to `cached_per_instance`.
|
||||
"""
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def my_id(self):
|
||||
_PICKLEABLE_CALL_COUNT[self] += 1
|
||||
return id(self)
|
||||
|
||||
|
||||
class InterfaceTests(test.TestCase):
|
||||
|
||||
def testMultipleAssignment(self):
|
||||
@ -169,120 +146,6 @@ class InterfaceTests(test.TestCase):
|
||||
self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
|
||||
self.evaluate(a.tensors))
|
||||
|
||||
def test_property_cache(self):
|
||||
test_counter = collections.Counter()
|
||||
|
||||
class MyObject(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
super(MyObject, self).__init__()
|
||||
self._frozen = True
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"""Enforce that cache does not set attribute on MyObject."""
|
||||
if getattr(self, "_frozen", False):
|
||||
raise ValueError("Cannot mutate when frozen.")
|
||||
return super(MyObject, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def test_property(self):
|
||||
test_counter[id(self)] += 1
|
||||
return id(self)
|
||||
|
||||
first_object = MyObject()
|
||||
second_object = MyObject()
|
||||
|
||||
# Make sure the objects return the correct values
|
||||
self.assertEqual(first_object.test_property, id(first_object))
|
||||
self.assertEqual(second_object.test_property, id(second_object))
|
||||
|
||||
# Make sure the cache does not share across objects
|
||||
self.assertNotEqual(first_object.test_property, second_object.test_property)
|
||||
|
||||
# Check again (Now the values should be cached.)
|
||||
self.assertEqual(first_object.test_property, id(first_object))
|
||||
self.assertEqual(second_object.test_property, id(second_object))
|
||||
|
||||
# Count the function calls to make sure the cache is actually being used.
|
||||
self.assertAllEqual(tuple(test_counter.values()), (1, 1))
|
||||
|
||||
def test_property_cache_threaded(self):
|
||||
call_count = collections.Counter()
|
||||
|
||||
class MyObject(tracking.AutoTrackable):
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def test_property(self):
|
||||
# Random sleeps to ensure that the execution thread changes
|
||||
# mid-computation.
|
||||
call_count["test_property"] += 1
|
||||
time.sleep(np.random.random() + 1.)
|
||||
|
||||
# Use a RandomState which is seeded off the instance's id (the mod is
|
||||
# because numpy limits the range of seeds) to ensure that an instance
|
||||
# returns the same value in different threads, but different instances
|
||||
# return different values.
|
||||
return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))
|
||||
|
||||
def get_test_property(self, _):
|
||||
"""Function provided to .map for threading test."""
|
||||
return self.test_property
|
||||
|
||||
# Test that multiple threads return the same value. This requires that
|
||||
# the underlying function is repeatable, as cached_property makes no attempt
|
||||
# to prioritize the first call.
|
||||
test_obj = MyObject()
|
||||
with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
|
||||
# Intentionally make a large pool (even when there are only a small number
|
||||
# of cpus) to ensure that the runtime switches threads.
|
||||
results = pool.map(test_obj.get_test_property, range(64))
|
||||
self.assertEqual(len(set(results)), 1)
|
||||
|
||||
# Make sure we actually are testing threaded behavior.
|
||||
self.assertGreater(call_count["test_property"], 1)
|
||||
|
||||
# Make sure new threads still cache hit.
|
||||
with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
|
||||
start_time = timeit.default_timer() # Don't time pool instantiation.
|
||||
results = pool.map(test_obj.get_test_property, range(4))
|
||||
total_time = timeit.default_timer() - start_time
|
||||
|
||||
# Note(taylorrobie): The reason that it is safe to time a unit test is that
|
||||
# a cache hit will be << 1 second, and a cache miss is
|
||||
# guaranteed to be >= 1 second. Empirically confirmed by
|
||||
# 100,000 runs with no flakes.
|
||||
self.assertLess(total_time, 0.95)
|
||||
|
||||
def test_property_cache_serialization(self):
|
||||
# Reset call count. .keys() must be wrapped in a list, because otherwise we
|
||||
# would mutate the iterator while iterating.
|
||||
for k in list(_PICKLEABLE_CALL_COUNT.keys()):
|
||||
_PICKLEABLE_CALL_COUNT.pop(k)
|
||||
|
||||
first_instance = MyPickleableObject()
|
||||
self.assertEqual(id(first_instance), first_instance.my_id)
|
||||
|
||||
# Test that we can pickle and un-pickle
|
||||
second_instance = pickle.loads(pickle.dumps(first_instance))
|
||||
|
||||
self.assertEqual(id(second_instance), second_instance.my_id)
|
||||
self.assertNotEqual(first_instance.my_id, second_instance.my_id)
|
||||
|
||||
# Make sure de-serialized object uses the cache.
|
||||
self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)
|
||||
|
||||
# Make sure the decorator cache is not being serialized with the object.
|
||||
expected_size = len(pickle.dumps(second_instance))
|
||||
for _ in range(5):
|
||||
# Add some more entries to the cache.
|
||||
_ = MyPickleableObject().my_id
|
||||
self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
|
||||
size_check_instance = MyPickleableObject()
|
||||
_ = size_check_instance.my_id
|
||||
self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))
|
||||
|
||||
|
||||
class _DummyResource(tracking.TrackableResource):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user