Reduce 1-layer Functional.__call__ overhead by ~10%
Adds a Layer._flatten_layers(recursive=True, include_self=True) method. Uses
this method for gathering unique sublayers, maintaining backwards-compatible
ordering. Removes unnecessary attribute caching for the stateful and dynamic
properties.

PiperOrigin-RevId: 314229252
Change-Id: I08cb80ae27861c52eae1ebed068b9a10d803e8a0
parent edeae9fb69
commit b9d99cb6f5
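For context on the ~10% figure in the title: a minimal sketch of how single-layer Functional call overhead can be measured in eager mode. The model, shapes, and iteration count below are illustrative assumptions, not the benchmark behind the quoted number.

# Rough, illustrative micro-benchmark for single-layer Functional.__call__
# overhead (not the benchmark used for the ~10% figure in this change).
import timeit

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)  # 1-layer Functional model

x = tf.constant(np.random.rand(32, 4), dtype=tf.float32)
model(x)  # warm up: builds weights and any per-call machinery

# Time eager __call__ only; framework overhead changes show up here rather
# than in the Dense matmul itself.
per_call = timeit.timeit(lambda: model(x), number=1000) / 1000
print('eager __call__ time: %.1f us' % (per_call * 1e6))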
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import functools
import itertools
@@ -1019,31 +1020,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return self._name

@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
def dynamic(self):
"""Whether the layer is dynamic (eager-only); set in the constructor."""
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())

def _unique_sublayers(self):
# Model.layers will use this as implementation, but we can't expose this
# one as the public property since it might conflict with subclass layers
# which also have user defined layers property.
self._maybe_create_attribute('_layers', [])
return list(
trackable_layer_utils.filter_empty_layer_containers(self._layers))
return any(layer._dynamic for layer in self._flatten_layers())

@property
@doc_controls.do_not_doc_inheritable
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
return any(layer._stateful for layer in self._flatten_layers())

@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')
def stateful(self, value):
self._stateful = value

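The hunk above folds the explicit `self._dynamic or ...` / `self._stateful or ...` term into the traversal itself. A toy, TensorFlow-free sketch of why that is behavior-preserving once the traversal includes `self`; class and attribute names below mimic, but are not, the real ones.

# Toy stand-in showing the property rewrite is behavior-preserving when the
# traversal includes `self` (as _flatten_layers(include_self=True) does).
class ToyLayer(object):

  def __init__(self, dynamic=False, sublayers=()):
    self._dynamic = dynamic
    self._sublayers = list(sublayers)

  def _flatten_layers(self):
    yield self  # include_self=True: the layer's own flag is covered
    for sub in self._sublayers:
      for layer in sub._flatten_layers():
        yield layer

  @property
  def dynamic(self):
    # New form: no special-casing of self._dynamic needed.
    return any(layer._dynamic for layer in self._flatten_layers())


outer = ToyLayer(dynamic=False, sublayers=[ToyLayer(dynamic=True)])
assert outer.dynamic                   # picked up from the sublayer
assert ToyLayer(dynamic=True).dynamic  # picked up from self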
@@ -1053,9 +1039,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):

@trainable.setter
def trainable(self, value):
self._trainable = value
for layer in getattr(self, '_layers', []):
layer.trainable = value
for layer in self._flatten_layers():
layer._trainable = value

@property
def activity_regularizer(self):
@@ -1162,7 +1147,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@doc_controls.do_not_doc_inheritable
def updates(self):
collected_updates = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
with backend.get_graph().as_default():
for layer in all_layers:
if not layer.trainable and not layer.stateful:
@@ -1215,8 +1200,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
A list of tensors.
"""
collected_losses = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
# If any eager losses are present, we assume the model to be part of an
# eager training loop (either a custom one or the one used when
# `run_eagerly=True`) and so we always return just the eager losses.
@@ -1357,12 +1341,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def _clear_losses(self):
"""Used every step in eager to reset losses."""
# Set to thread local directly to avoid Layer.__setattr__ overhead.
self._thread_local._eager_losses = []
sublayers = getattr(self, '_layers', [])
if sublayers:
sublayers = trackable_layer_utils.filter_empty_layer_containers(sublayers)
for layer in sublayers:
layer._clear_losses()
if not getattr(self, '_layers', None): # Fast path for single Layer.
self._thread_local._eager_losses = []
else:
for layer in self._flatten_layers():
layer._thread_local._eager_losses = []

@property
def metrics(self):
@@ -1382,8 +1365,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
A list of tensors.
"""
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
with layer._metrics_lock:
collected_metrics.extend(layer._metrics)
return collected_metrics
@@ -2507,22 +2489,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Returns:
A dict mapping all sublayers to their `trainable` value.
"""
layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
# Keep track of each top-level layers' `trainable` as well as the
# state of all of its sublayers.
trainable_state = weakref.WeakKeyDictionary()
trainable_state[self] = self.trainable
for layer in layers:
trainable_state.update(layer._get_trainable_state())
for layer in self._flatten_layers():
trainable_state[layer] = layer.trainable
return trainable_state

def _set_trainable_state(self, trainable_state):
"""Set `trainable` state for each sublayer."""
layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
if self in trainable_state:
self.trainable = trainable_state[self]
for layer in layers:
layer._set_trainable_state(trainable_state)
for layer in self._flatten_layers():
if layer in trainable_state:
layer.trainable = trainable_state[layer]

@property
def _obj_reference_counts(self):
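The `_get_trainable_state`/`_set_trainable_state` hunk above replaces per-sublayer recursion with a single pass over the flattened layers. A minimal sketch of the same snapshot/freeze/restore pattern expressed with the public API only; the model and layer sizes are arbitrary assumptions.

# Snapshot every layer's `trainable` flag, freeze the model, then restore.
# Mirrors the flat get/set pattern above using only the public API.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])

state = {layer: layer.trainable for layer in model.layers}  # snapshot
for layer in model.layers:
  layer.trainable = False                                    # freeze

for layer, trainable in state.items():
  layer.trainable = trainable                                # restore
assert all(layer.trainable for layer in model.layers)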
@@ -2582,7 +2558,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
self._attribute_sentinel.invalidate_all()
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@@ -2591,13 +2566,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
'_non_trainable_weights',
[w for w in self._non_trainable_weights if w is not existing_value])

# Any time we change `_layers` (either by deleting the attribute or by
# reassigning it which will call __delattr__ from __setattr__) the topology
# of the subgraph of Layers may change. In that case we will need to
# recompute any attribute which depends on that subgraph.
if name == '_layers':
self._attribute_sentinel.invalidate_all()

def __setattr__(self, name, value):
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
@@ -2642,8 +2610,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_attribute_sentinel'):
value._attribute_sentinel.add_parent(self._attribute_sentinel)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@@ -2691,34 +2657,36 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
getattr(layer, attribute) for layer in nested_layers))
return []

def _gather_unique_layers(self):
"""Returns the current layer and all its children depth first deduped.
def _flatten_layers(self, recursive=True, include_self=True):
if include_self:
yield self

We are deduping after getting the layers to maintain the order.
"""
all_layers = self._gather_layers()
unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
for layer in all_layers:
if layer not in seen_layers:
unique_layers.append(layer)
# Track the Variable's identity to avoid __eq__ issues.
seen_layers.add(layer)
return unique_layers
# Only instantiate set and deque if needed.
layers_or_containers = getattr(self, '_layers', None)
if layers_or_containers:
seen_object_ids = set()
deque = collections.deque(layers_or_containers)
while deque:
layer_or_container = deque.popleft()

def _gather_layers(self):
"""Returns the current layer and all its children depth first."""
all_layers = [self]
if hasattr(self, '_layers'):
child_layers = trackable_layer_utils.filter_empty_layer_containers(
self._layers)
for child_layer in child_layers:
all_layers.extend(child_layer._gather_layers())
return all_layers
layer_or_container_id = id(layer_or_container)
if layer_or_container_id in seen_object_ids:
continue
seen_object_ids.add(layer_or_container_id)

@property
@tracking.cached_per_instance
def _attribute_sentinel(self):
return trackable_layer_utils.AttributeSentinel()
if isinstance(layer_or_container, Layer):
yield layer_or_container
# Introspect recursively through sublayers.
if recursive:
sublayers = getattr(layer_or_container, '_layers', None)
if sublayers:
deque.extendleft(reversed(sublayers))
elif isinstance(layer_or_container,
data_structures.TrackableDataStructure):
# Data structures are introspected even with `recursive=False`.
tracked_values = layer_or_container._values
if tracked_values:
deque.extendleft(reversed(tracked_values))

# This is a hack so that the is_layer (within
# training/trackable/layer_utils.py) check doesn't get the weights attr.
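One detail in the new `_flatten_layers` traversal above: `deque.extendleft` pushes items onto the left one at a time, which reverses their order, so pairing it with `reversed(...)` keeps sublayers in their original tracking order and makes `popleft` visit them depth-first. A quick standalone check:

# extendleft() prepends items one at a time, reversing them; wrapping the
# sequence in reversed() restores the order, so children are visited in the
# order they were tracked when popped with popleft().
import collections

d = collections.deque(['rest'])
d.extendleft(reversed(['a', 'b', 'c']))
assert list(d) == ['a', 'b', 'c', 'rest']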
@@ -829,22 +829,15 @@ class Layer(base_layer.Layer):
return self._name

@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
def dynamic(self):
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())
return any(layer._dynamic for layer in self._flatten_layers())

@property
@doc_controls.do_not_generate_docs
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
return any(layer._stateful for layer in self._flatten_layers())

@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')
def stateful(self, value):
self._stateful = value

@@ -916,7 +909,7 @@ class Layer(base_layer.Layer):
@property
def updates(self):
collected_updates = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
with backend.get_graph().as_default():
for layer in all_layers:
if not layer.trainable and not layer.stateful:
@@ -945,7 +938,7 @@ class Layer(base_layer.Layer):
A list of tensors.
"""
collected_losses = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
for layer in all_layers:
# If any eager losses are present, we assume the model to be part of an
# eager training loop (either a custom one or the one used when
@@ -1075,8 +1068,7 @@ class Layer(base_layer.Layer):
@property
def metrics(self):
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
collected_metrics.extend(layer._metrics)
return collected_metrics

@@ -2187,7 +2179,6 @@ class Layer(base_layer.Layer):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
self._attribute_sentinel.invalidate_all()
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@@ -2196,13 +2187,6 @@ class Layer(base_layer.Layer):
'_non_trainable_weights',
[w for w in self._non_trainable_weights if w is not existing_value])

# Any time we change `_layers` (either by deleting the attribute or by
# reassigning it which will call __delattr__ from __setattr__) the topology
# of the subgraph of Layers may change. In that case we will need to
# recompute any attribute which depends on that subgraph.
if name == '_layers':
self._attribute_sentinel.invalidate_all()

def __setattr__(self, name, value):
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
@@ -2247,8 +2231,6 @@ class Layer(base_layer.Layer):
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_attribute_sentinel'):
value._attribute_sentinel.add_parent(self._attribute_sentinel)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@@ -2296,35 +2278,6 @@ class Layer(base_layer.Layer):
getattr(layer, attribute) for layer in nested_layers))
return []

def _gather_unique_layers(self):
"""Returns the current layer and all its children depth first deduped.

We are deduping after getting the layers to maintain the order.
"""
all_layers = self._gather_layers()
unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
for layer in all_layers:
if layer not in seen_layers:
unique_layers.append(layer)
# Track the Variable's identity to avoid __eq__ issues.
seen_layers.add(layer)
return unique_layers

def _gather_layers(self):
"""Returns the current layer and all its children depth first."""
all_layers = [self]
if hasattr(self, '_layers'):
child_layers = trackable_layer_utils.filter_empty_layer_containers(
self._layers)
for child_layer in child_layers:
all_layers.extend(child_layer._gather_layers())
return all_layers

@property
@tracking.cached_per_instance
def _attribute_sentinel(self):
return trackable_layer_utils.AttributeSentinel()

# This is a hack so that the is_layer (within
# training/trackable/layer_utils.py) check doesn't get the weights attr.
# TODO(b/110718070): Remove when fixed.
@@ -193,7 +193,6 @@ class Functional(training_lib.Model):
self._layer_call_argspecs = {}
for layer in self._layers:
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
layer._attribute_sentinel.add_parent(self._attribute_sentinel)

# Build self.input_names and self.output_names.
self._set_output_names()
@@ -731,10 +730,6 @@ class Functional(training_lib.Model):
self._layers.append(layer)
deferred_layers.append(layer)
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

# This allows the added layer to broadcast mutations to the current
# layer, which is necessary to ensure cache correctness.
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
layer_set.add(layer)
self._handle_deferred_layer_dependencies(deferred_layers)

@@ -188,10 +188,6 @@ class Sequential(functional.Functional):
' of a layer in this model. Update the `name` argument '
'to pass a unique name.' % (layer.name,))

# This allows the added layer to broadcast mutations to the current
# layer, which is necessary to ensure cache correctness.
layer._attribute_sentinel.add_parent(self._attribute_sentinel)

self.built = False
set_inputs = False
if not self._layers:
@@ -236,9 +232,6 @@ class Sequential(functional.Functional):
self._handle_deferred_layer_dependencies([layer])

self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# Different Model types add to `._layers` in different ways, so for safety
# we do a cache invalidation to make sure the changes are reflected.
self._attribute_sentinel.invalidate_all()

@trackable.no_automatic_dependency_tracking
def pop(self):
@@ -252,7 +245,6 @@ class Sequential(functional.Functional):

layer = self._layers.pop()
self._layer_call_argspecs.pop(layer)
self._attribute_sentinel.invalidate_all()
if not self.layers:
self.outputs = None
self.inputs = None
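The Sequential hunks above simplify the bookkeeping in `add` and `pop` without changing their user-facing contract. A short usage sketch of that contract; the layer sizes are arbitrary assumptions.

# Sequential.add/pop user-facing behavior that the simplified bookkeeping
# above still has to preserve.
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(2,)))
model.add(tf.keras.layers.Dense(1))
assert len(model.layers) == 2

model.pop()                         # removes the last layer
assert len(model.layers) == 1
assert model.layers[-1].units == 4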
@@ -628,8 +628,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
if self.compiled_metrics is not None:
metrics += self.compiled_metrics.metrics

all_layers = self._gather_unique_layers()
for l in all_layers:
for l in self._flatten_layers():
metrics.extend(l._metrics) # pylint: disable=protected-access
return metrics

@@ -2310,7 +2309,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):

@property
def layers(self):
return self._unique_sublayers()
return list(self._flatten_layers(include_self=False, recursive=False))

def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
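With `Model.layers` now defined as `_flatten_layers(include_self=False, recursive=False)`, nested models appear as single entries rather than being expanded. A brief usage sketch of what that means; the layer sizes are arbitrary and the printed names depend on the session's name counters.

# `layers` lists only direct sublayers; a nested model shows up as one entry
# and its own sublayers are not expanded here.
import tensorflow as tf

inner = tf.keras.Sequential([tf.keras.layers.Dense(2)], name='inner')
outer = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, input_shape=(3,)), inner], name='outer')

print([l.name for l in outer.layers])  # e.g. ['dense', 'inner']
assert inner in outer.layers           # the nested model itself is listed
assert len(outer.layers) == 2          # its inner Dense(2) is not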
@@ -1499,6 +1499,65 @@ class TrainingTest(keras_parameterized.TestCase):
new_kernel = model.get_weights()[1]
self.assertNotAllEqual(old_kernel, new_kernel)

@keras_parameterized.run_all_keras_modes
def test_layer_ordering(self):

class MyLayer(layers_module.Layer):
pass

class MyModel(training_module.Model):

def __init__(self, name):
super(MyModel, self).__init__(name=name)

self.weight = variables_lib.Variable(0, name=name)

self.direct_sublayer = MyLayer(name='direct')
self.direct_sublayer.d = {'d': MyLayer(name='direct/dict')}

self.dict_sublayer = {'d': MyLayer(name='dict')}
self.dict_sublayer['d'].direct = MyLayer(name='dict/direct')

model = MyModel('model')
# All sublayers, including self and recursive sublayers.
self.assertEqual(['model', 'direct', 'direct/dict', 'dict', 'dict/direct'],
[l.name for l in model._flatten_layers()])
# Only direct sublayers, including those in data structures.
self.assertEqual(['direct', 'dict'], [l.name for l in model.layers])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_trainable_state_setting(self):

class UpdateLayer(layers_module.Layer):

def __init__(self):
super(UpdateLayer, self).__init__()
self.v = variables_lib.Variable(0., trainable=False)

def call(self, x):
self.add_update(lambda: self.v.assign_add(1.))
return x * self.v

layer = UpdateLayer()
model_with_updates = sequential.Sequential([layer])
model_with_updates.compile(
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())

layer.trainable = False
model_without_updates = sequential.Sequential([layer])
model_without_updates.compile(
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())

x, y = np.ones((10, 1)), np.ones((10, 1))

self.assertEqual(self.evaluate(layer.v), 0.)
model_with_updates.fit(x, y, batch_size=10)
# assign_add called.
self.assertEqual(self.evaluate(layer.v), 1.)
model_without_updates.fit(x, y, batch_size=10)
# assign_add not called.
self.assertEqual(self.evaluate(layer.v), 1.)


class TestExceptionsAndWarnings(keras_parameterized.TestCase):

@@ -505,11 +505,6 @@ def _in_place_subclassed_model_reset(model):
setattr(model, name, fresh_layer)
model._layers.append(fresh_layer)

# The base Layer __setattr__ will invalidate its attribute cache when
# `._layers` is assigned, but it has no way to know when the underlying list
# is mutated so we must explicitly signal the append.
model._attribute_sentinel.invalidate_all()

# Cache original model build attributes (in addition to layers)
if (not hasattr(model, '_original_attributes_cache') or
model._original_attributes_cache is None):
@@ -79,7 +79,7 @@ def use_wrapped_call(layer, call_fn, default_training_value=None,
# child layers. This causes `.losses` to only return eager losses.
# pylint: disable=protected-access
if context.executing_eagerly():
for i in layer._gather_unique_layers():
for i in layer._flatten_layers():
if i is not layer:
i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
# pylint: enable=protected-access