Reduce 1-layer Functional.__call__ overhead by ~10%

Adds a Layer._flatten_layers(recursive=True, include_self=True) method.
Uses this method to gather unique sublayers while maintaining
backwards-compatible ordering.

Removes unnecessary attribute caching for stateful and dynamic properties.
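
A minimal usage sketch of the new method, assuming a small tf.keras model (the example model is illustrative only; _flatten_layers is internal API, and the variable names below are not from this change):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# self plus all nested sublayers, deduped, in tracking order
all_layers = list(model._flatten_layers())
# everything below the model, still recursive
sublayers = list(model._flatten_layers(include_self=False))
# direct children only; this is what Model.layers now returns
direct_children = list(model._flatten_layers(include_self=False, recursive=False))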

PiperOrigin-RevId: 314229252
Change-Id: I08cb80ae27861c52eae1ebed068b9a10d803e8a0
Thomas O'Malley 2020-06-01 16:46:21 -07:00 committed by TensorFlower Gardener
parent edeae9fb69
commit b9d99cb6f5
8 changed files with 112 additions and 151 deletions

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import itertools
@@ -1019,31 +1020,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return self._name
@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
def dynamic(self):
"""Whether the layer is dynamic (eager-only); set in the constructor."""
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())
def _unique_sublayers(self):
# Model.layers will use this as implementation, but we can't expose this
# one as the public property since it might conflict with subclass layers
# which also have user defined layers property.
self._maybe_create_attribute('_layers', [])
return list(
trackable_layer_utils.filter_empty_layer_containers(self._layers))
return any(layer._dynamic for layer in self._flatten_layers())
@property
@doc_controls.do_not_doc_inheritable
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
return any(layer._stateful for layer in self._flatten_layers())
@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')
def stateful(self, value):
self._stateful = value
@@ -1053,9 +1039,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@trainable.setter
def trainable(self, value):
self._trainable = value
for layer in getattr(self, '_layers', []):
layer.trainable = value
for layer in self._flatten_layers():
layer._trainable = value
@property
def activity_regularizer(self):
@@ -1162,7 +1147,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@doc_controls.do_not_doc_inheritable
def updates(self):
collected_updates = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
with backend.get_graph().as_default():
for layer in all_layers:
if not layer.trainable and not layer.stateful:
@@ -1215,8 +1200,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
A list of tensors.
"""
collected_losses = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
# If any eager losses are present, we assume the model to be part of an
# eager training loop (either a custom one or the one used when
# `run_eagerly=True`) and so we always return just the eager losses.
@@ -1357,12 +1341,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def _clear_losses(self):
"""Used every step in eager to reset losses."""
# Set to thread local directly to avoid Layer.__setattr__ overhead.
self._thread_local._eager_losses = []
sublayers = getattr(self, '_layers', [])
if sublayers:
sublayers = trackable_layer_utils.filter_empty_layer_containers(sublayers)
for layer in sublayers:
layer._clear_losses()
if not getattr(self, '_layers', None): # Fast path for single Layer.
self._thread_local._eager_losses = []
else:
for layer in self._flatten_layers():
layer._thread_local._eager_losses = []
@property
def metrics(self):
@@ -1382,8 +1365,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
A list of tensors.
"""
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
with layer._metrics_lock:
collected_metrics.extend(layer._metrics)
return collected_metrics
@@ -2507,22 +2489,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Returns:
A dict mapping all sublayers to their `trainable` value.
"""
layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
# Keep track of each top-level layers' `trainable` as well as the
# state of all of its sublayers.
trainable_state = weakref.WeakKeyDictionary()
trainable_state[self] = self.trainable
for layer in layers:
trainable_state.update(layer._get_trainable_state())
for layer in self._flatten_layers():
trainable_state[layer] = layer.trainable
return trainable_state
def _set_trainable_state(self, trainable_state):
"""Set `trainable` state for each sublayer."""
layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
if self in trainable_state:
self.trainable = trainable_state[self]
for layer in layers:
layer._set_trainable_state(trainable_state)
for layer in self._flatten_layers():
if layer in trainable_state:
layer.trainable = trainable_state[layer]
@property
def _obj_reference_counts(self):
@@ -2582,7 +2558,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
self._attribute_sentinel.invalidate_all()
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@@ -2591,13 +2566,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
'_non_trainable_weights',
[w for w in self._non_trainable_weights if w is not existing_value])
# Any time we change `_layers` (either by deleting the attribute or by
# reassigning it which will call __delattr__ from __setattr__) the topology
# of the subgraph of Layers may change. In that case we will need to
# recompute any attribute which depends on that subgraph.
if name == '_layers':
self._attribute_sentinel.invalidate_all()
def __setattr__(self, name, value):
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
@@ -2642,8 +2610,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_attribute_sentinel'):
value._attribute_sentinel.add_parent(self._attribute_sentinel)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@@ -2691,34 +2657,36 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
getattr(layer, attribute) for layer in nested_layers))
return []
def _gather_unique_layers(self):
"""Returns the current layer and all its children depth first deduped.
def _flatten_layers(self, recursive=True, include_self=True):
if include_self:
yield self
We are deduping after getting the layers to maintain the order.
"""
all_layers = self._gather_layers()
unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
for layer in all_layers:
if layer not in seen_layers:
unique_layers.append(layer)
# Track the Variable's identity to avoid __eq__ issues.
seen_layers.add(layer)
return unique_layers
# Only instantiate set and deque if needed.
layers_or_containers = getattr(self, '_layers', None)
if layers_or_containers:
seen_object_ids = set()
deque = collections.deque(layers_or_containers)
while deque:
layer_or_container = deque.popleft()
def _gather_layers(self):
"""Returns the current layer and all its children depth first."""
all_layers = [self]
if hasattr(self, '_layers'):
child_layers = trackable_layer_utils.filter_empty_layer_containers(
self._layers)
for child_layer in child_layers:
all_layers.extend(child_layer._gather_layers())
return all_layers
layer_or_container_id = id(layer_or_container)
if layer_or_container_id in seen_object_ids:
continue
seen_object_ids.add(layer_or_container_id)
@property
@tracking.cached_per_instance
def _attribute_sentinel(self):
return trackable_layer_utils.AttributeSentinel()
if isinstance(layer_or_container, Layer):
yield layer_or_container
# Introspect recursively through sublayers.
if recursive:
sublayers = getattr(layer_or_container, '_layers', None)
if sublayers:
deque.extendleft(reversed(sublayers))
elif isinstance(layer_or_container,
data_structures.TrackableDataStructure):
# Data structures are introspected even with `recursive=False`.
tracked_values = layer_or_container._values
if tracked_values:
deque.extendleft(reversed(tracked_values))
# This is a hack so that the is_layer (within
# training/trackable/layer_utils.py) check doesn't get the weights attr.

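For reference, a standalone sketch of the traversal that the new _flatten_layers performs (simplified: it assumes each sublayer is tracked in a plain _layers list and omits the TrackableDataStructure handling shown above; not the actual Keras implementation):

import collections

def flatten_layers(root, recursive=True, include_self=True):
  """Yield `root` and its unique sublayers in tracking order (sketch)."""
  if include_self:
    yield root
  queue = collections.deque(getattr(root, '_layers', []) or [])
  seen_ids = set()
  while queue:
    layer = queue.popleft()
    if id(layer) in seen_ids:
      continue
    seen_ids.add(id(layer))
    yield layer
    if recursive:
      sublayers = getattr(layer, '_layers', None)
      if sublayers:
        # Prepend children so they are visited right after their parent
        # (depth-first order, preserving the order in which they were tracked).
        queue.extendleft(reversed(sublayers))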
View File

@@ -829,22 +829,15 @@ class Layer(base_layer.Layer):
return self._name
@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
def dynamic(self):
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())
return any(layer._dynamic for layer in self._flatten_layers())
@property
@doc_controls.do_not_generate_docs
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
return any(layer._stateful for layer in self._flatten_layers())
@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')
def stateful(self, value):
self._stateful = value
@@ -916,7 +909,7 @@ class Layer(base_layer.Layer):
@property
def updates(self):
collected_updates = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
with backend.get_graph().as_default():
for layer in all_layers:
if not layer.trainable and not layer.stateful:
@@ -945,7 +938,7 @@ class Layer(base_layer.Layer):
A list of tensors.
"""
collected_losses = []
all_layers = self._gather_unique_layers()
all_layers = self._flatten_layers()
for layer in all_layers:
# If any eager losses are present, we assume the model to be part of an
# eager training loop (either a custom one or the one used when
@@ -1075,8 +1068,7 @@ class Layer(base_layer.Layer):
@property
def metrics(self):
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
for layer in self._flatten_layers():
collected_metrics.extend(layer._metrics)
return collected_metrics
@@ -2187,7 +2179,6 @@ class Layer(base_layer.Layer):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
self._attribute_sentinel.invalidate_all()
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@@ -2196,13 +2187,6 @@ class Layer(base_layer.Layer):
'_non_trainable_weights',
[w for w in self._non_trainable_weights if w is not existing_value])
# Any time we change `_layers` (either by deleting the attribute or by
# reassigning it which will call __delattr__ from __setattr__) the topology
# of the subgraph of Layers may change. In that case we will need to
# recompute any attribute which depends on that subgraph.
if name == '_layers':
self._attribute_sentinel.invalidate_all()
def __setattr__(self, name, value):
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
@@ -2247,8 +2231,6 @@ class Layer(base_layer.Layer):
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_attribute_sentinel'):
value._attribute_sentinel.add_parent(self._attribute_sentinel)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@@ -2296,35 +2278,6 @@ class Layer(base_layer.Layer):
getattr(layer, attribute) for layer in nested_layers))
return []
def _gather_unique_layers(self):
"""Returns the current layer and all its children depth first deduped.
We are deduping after getting the layers to maintain the order.
"""
all_layers = self._gather_layers()
unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
for layer in all_layers:
if layer not in seen_layers:
unique_layers.append(layer)
# Track the Variable's identity to avoid __eq__ issues.
seen_layers.add(layer)
return unique_layers
def _gather_layers(self):
"""Returns the current layer and all its children depth first."""
all_layers = [self]
if hasattr(self, '_layers'):
child_layers = trackable_layer_utils.filter_empty_layer_containers(
self._layers)
for child_layer in child_layers:
all_layers.extend(child_layer._gather_layers())
return all_layers
@property
@tracking.cached_per_instance
def _attribute_sentinel(self):
return trackable_layer_utils.AttributeSentinel()
# This is a hack so that the is_layer (within
# training/trackable/layer_utils.py) check doesn't get the weights attr.
# TODO(b/110718070): Remove when fixed.

View File

@@ -193,7 +193,6 @@ class Functional(training_lib.Model):
self._layer_call_argspecs = {}
for layer in self._layers:
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
# Build self.input_names and self.output_names.
self._set_output_names()
@@ -731,10 +730,6 @@ class Functional(training_lib.Model):
self._layers.append(layer)
deferred_layers.append(layer)
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# This allows the added layer to broadcast mutations to the current
# layer, which is necessary to ensure cache correctness.
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
layer_set.add(layer)
self._handle_deferred_layer_dependencies(deferred_layers)

View File

@@ -188,10 +188,6 @@ class Sequential(functional.Functional):
' of a layer in this model. Update the `name` argument '
'to pass a unique name.' % (layer.name,))
# This allows the added layer to broadcast mutations to the current
# layer, which is necessary to ensure cache correctness.
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
self.built = False
set_inputs = False
if not self._layers:
@@ -236,9 +232,6 @@ class Sequential(functional.Functional):
self._handle_deferred_layer_dependencies([layer])
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# Different Model types add to `._layers` in different ways, so for safety
# we do a cache invalidation to make sure the changes are reflected.
self._attribute_sentinel.invalidate_all()
@trackable.no_automatic_dependency_tracking
def pop(self):
@@ -252,7 +245,6 @@ class Sequential(functional.Functional):
layer = self._layers.pop()
self._layer_call_argspecs.pop(layer)
self._attribute_sentinel.invalidate_all()
if not self.layers:
self.outputs = None
self.inputs = None

View File

@@ -628,8 +628,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
if self.compiled_metrics is not None:
metrics += self.compiled_metrics.metrics
all_layers = self._gather_unique_layers()
for l in all_layers:
for l in self._flatten_layers():
metrics.extend(l._metrics) # pylint: disable=protected-access
return metrics
@@ -2310,7 +2309,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
@property
def layers(self):
return self._unique_sublayers()
return list(self._flatten_layers(include_self=False, recursive=False))
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.

View File

@@ -1499,6 +1499,65 @@ class TrainingTest(keras_parameterized.TestCase):
new_kernel = model.get_weights()[1]
self.assertNotAllEqual(old_kernel, new_kernel)
@keras_parameterized.run_all_keras_modes
def test_layer_ordering(self):
class MyLayer(layers_module.Layer):
pass
class MyModel(training_module.Model):
def __init__(self, name):
super(MyModel, self).__init__(name=name)
self.weight = variables_lib.Variable(0, name=name)
self.direct_sublayer = MyLayer(name='direct')
self.direct_sublayer.d = {'d': MyLayer(name='direct/dict')}
self.dict_sublayer = {'d': MyLayer(name='dict')}
self.dict_sublayer['d'].direct = MyLayer(name='dict/direct')
model = MyModel('model')
# All sublayers, including self and recursive sublayers.
self.assertEqual(['model', 'direct', 'direct/dict', 'dict', 'dict/direct'],
[l.name for l in model._flatten_layers()])
# Only direct sublayers, including those in data structures.
self.assertEqual(['direct', 'dict'], [l.name for l in model.layers])
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_trainable_state_setting(self):
class UpdateLayer(layers_module.Layer):
def __init__(self):
super(UpdateLayer, self).__init__()
self.v = variables_lib.Variable(0., trainable=False)
def call(self, x):
self.add_update(lambda: self.v.assign_add(1.))
return x * self.v
layer = UpdateLayer()
model_with_updates = sequential.Sequential([layer])
model_with_updates.compile(
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
layer.trainable = False
model_without_updates = sequential.Sequential([layer])
model_without_updates.compile(
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
x, y = np.ones((10, 1)), np.ones((10, 1))
self.assertEqual(self.evaluate(layer.v), 0.)
model_with_updates.fit(x, y, batch_size=10)
# assign_add called.
self.assertEqual(self.evaluate(layer.v), 1.)
model_without_updates.fit(x, y, batch_size=10)
# assign_add not called.
self.assertEqual(self.evaluate(layer.v), 1.)
class TestExceptionsAndWarnings(keras_parameterized.TestCase):

View File

@@ -505,11 +505,6 @@ def _in_place_subclassed_model_reset(model):
setattr(model, name, fresh_layer)
model._layers.append(fresh_layer)
# The base Layer __setattr__ will invalidate its attribute cache when
# `._layers` is assigned, but it has no way to know when the underlying list
# is mutated so we must explicitly signal the append.
model._attribute_sentinel.invalidate_all()
# Cache original model build attributes (in addition to layers)
if (not hasattr(model, '_original_attributes_cache') or
model._original_attributes_cache is None):

View File

@@ -79,7 +79,7 @@ def use_wrapped_call(layer, call_fn, default_training_value=None,
# child layers. This causes `.losses` to only return eager losses.
# pylint: disable=protected-access
if context.executing_eagerly():
for i in layer._gather_unique_layers():
for i in layer._flatten_layers():
if i is not layer:
i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
# pylint: enable=protected-access