Reduce 1-layer Functional.__call__ overhead by ~10%
Adds a Layer._flatten_layers(recursive=True, include_self=True) method. Uses this method for gathering unique sublayers, maintaining backwards compatible ordering. Removes unnecessary attribute caching for stateful and dynamic properties.

PiperOrigin-RevId: 314229252
Change-Id: I08cb80ae27861c52eae1ebed068b9a10d803e8a0
parent edeae9fb69
commit b9d99cb6f5
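At its core, the change replaces two recursive helpers (`_gather_layers` plus an `ObjectIdentitySet` dedupe pass) with a single generator: a deque-driven depth-first walk over `_layers` that dedupes by object id as it goes, so shared sublayers are yielded once, in first-visit order. Below is a minimal standalone sketch of that traversal logic outside Keras; the `Node` and `flatten` names are hypothetical stand-ins, not part of the Keras API.

import collections

class Node(object):
  """Hypothetical stand-in for a Layer with tracked sublayers."""

  def __init__(self, name, children=()):
    self.name = name
    self._layers = list(children)

def flatten(node, recursive=True, include_self=True):
  if include_self:
    yield node
  children = getattr(node, '_layers', None)
  if children:
    seen_ids = set()  # Dedupe by identity, not __eq__.
    deque = collections.deque(children)
    while deque:
      child = deque.popleft()
      if id(child) in seen_ids:
        continue
      seen_ids.add(id(child))
      yield child
      if recursive and getattr(child, '_layers', None):
        # extendleft(reversed(...)) preserves depth-first, left-to-right order.
        deque.extendleft(reversed(child._layers))

# A sublayer shared by two parents is yielded only once, at first visit.
shared = Node('shared')
root = Node('root', [Node('a', [shared]), Node('b', [shared])])
assert [n.name for n in flatten(root)] == ['root', 'a', 'shared', 'b']

Because nothing is materialized up front (no intermediate list, no identity set unless sublayers exist), callers that only need the layer itself pay almost nothing, which is where the ~10% `__call__` saving comes from.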
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 import functools
 import itertools
@@ -1019,31 +1020,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     return self._name
 
   @property
-  @trackable_layer_utils.cache_recursive_attribute('dynamic')
   def dynamic(self):
     """Whether the layer is dynamic (eager-only); set in the constructor."""
-    # NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
-    # then this cache logic must be updated.
-    return self._dynamic or any(layer.dynamic
-                                for layer in self._unique_sublayers())
-
-  def _unique_sublayers(self):
-    # Model.layers will use this as implementation, but we can't expose this
-    # one as the public property since it might conflict with subclass layers
-    # which also have user defined layers property.
-    self._maybe_create_attribute('_layers', [])
-    return list(
-        trackable_layer_utils.filter_empty_layer_containers(self._layers))
+    return any(layer._dynamic for layer in self._flatten_layers())
 
   @property
   @doc_controls.do_not_doc_inheritable
-  @trackable_layer_utils.cache_recursive_attribute('stateful')
   def stateful(self):
-    return self._stateful or any(
-        getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
+    return any(layer._stateful for layer in self._flatten_layers())
 
   @stateful.setter
-  @trackable_layer_utils.invalidate_recursive_cache('stateful')
   def stateful(self, value):
     self._stateful = value
 
@@ -1053,9 +1039,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 
   @trainable.setter
   def trainable(self, value):
-    self._trainable = value
-    for layer in getattr(self, '_layers', []):
-      layer.trainable = value
+    for layer in self._flatten_layers():
+      layer._trainable = value
 
   @property
   def activity_regularizer(self):
@@ -1162,7 +1147,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   @doc_controls.do_not_doc_inheritable
   def updates(self):
     collected_updates = []
-    all_layers = self._gather_unique_layers()
+    all_layers = self._flatten_layers()
     with backend.get_graph().as_default():
       for layer in all_layers:
         if not layer.trainable and not layer.stateful:
@@ -1215,8 +1200,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       A list of tensors.
     """
     collected_losses = []
-    all_layers = self._gather_unique_layers()
-    for layer in all_layers:
+    for layer in self._flatten_layers():
       # If any eager losses are present, we assume the model to be part of an
       # eager training loop (either a custom one or the one used when
       # `run_eagerly=True`) and so we always return just the eager losses.
@@ -1357,12 +1341,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   def _clear_losses(self):
     """Used every step in eager to reset losses."""
     # Set to thread local directly to avoid Layer.__setattr__ overhead.
-    self._thread_local._eager_losses = []
-    sublayers = getattr(self, '_layers', [])
-    if sublayers:
-      sublayers = trackable_layer_utils.filter_empty_layer_containers(sublayers)
-      for layer in sublayers:
-        layer._clear_losses()
+    if not getattr(self, '_layers', None):  # Fast path for single Layer.
+      self._thread_local._eager_losses = []
+    else:
+      for layer in self._flatten_layers():
+        layer._thread_local._eager_losses = []
 
   @property
   def metrics(self):
@@ -1382,8 +1365,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       A list of tensors.
     """
     collected_metrics = []
-    all_layers = self._gather_unique_layers()
-    for layer in all_layers:
+    for layer in self._flatten_layers():
       with layer._metrics_lock:
         collected_metrics.extend(layer._metrics)
     return collected_metrics
@@ -2507,22 +2489,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     Returns:
       A dict mapping all sublayers to their `trainable` value.
     """
-    layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
-    # Keep track of each top-level layers' `trainable` as well as the
-    # state of all of its sublayers.
     trainable_state = weakref.WeakKeyDictionary()
-    trainable_state[self] = self.trainable
-    for layer in layers:
-      trainable_state.update(layer._get_trainable_state())
+    for layer in self._flatten_layers():
+      trainable_state[layer] = layer.trainable
     return trainable_state
 
   def _set_trainable_state(self, trainable_state):
     """Set `trainable` state for each sublayer."""
-    layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
-    if self in trainable_state:
-      self.trainable = trainable_state[self]
-    for layer in layers:
-      layer._set_trainable_state(trainable_state)
+    for layer in self._flatten_layers():
+      if layer in trainable_state:
+        layer.trainable = trainable_state[layer]
 
   @property
   def _obj_reference_counts(self):
@@ -2582,7 +2558,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       super(tracking.AutoTrackable, self).__setattr__(
           '_layers',
          [l for l in self._layers if l is not existing_value])
-      self._attribute_sentinel.invalidate_all()
     if isinstance(existing_value, tf_variables.Variable):
       super(tracking.AutoTrackable, self).__setattr__(
           '_trainable_weights',
@@ -2591,13 +2566,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           '_non_trainable_weights',
           [w for w in self._non_trainable_weights if w is not existing_value])
 
-    # Any time we change `_layers` (either by deleting the attribute or by
-    # reassigning it which will call __delattr__ from __setattr__) the topology
-    # of the subgraph of Layers may change. In that case we will need to
-    # recompute any attribute which depends on that subgraph.
-    if name == '_layers':
-      self._attribute_sentinel.invalidate_all()
-
   def __setattr__(self, name, value):
     if (name == '_self_setattr_tracking' or
         not getattr(self, '_self_setattr_tracking', True) or
@@ -2642,8 +2610,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         # container types which compare equal.
         if not any((layer is value for layer in self._layers)):
           self._layers.append(value)
-          if hasattr(value, '_attribute_sentinel'):
-            value._attribute_sentinel.add_parent(self._attribute_sentinel)
       if hasattr(value, '_use_resource_variables'):
         # Legacy layers (V1 tf.layers) must always use
         # resource variables.
@@ -2691,34 +2657,36 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
             getattr(layer, attribute) for layer in nested_layers))
     return []
 
-  def _gather_unique_layers(self):
-    """Returns the current layer and all its children depth first deduped.
-
-    We are deduping after getting the layers to maintain the order.
-    """
-    all_layers = self._gather_layers()
-    unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
-    for layer in all_layers:
-      if layer not in seen_layers:
-        unique_layers.append(layer)
-        # Track the Variable's identity to avoid __eq__ issues.
-        seen_layers.add(layer)
-    return unique_layers
-
-  def _gather_layers(self):
-    """Returns the current layer and all its children depth first."""
-    all_layers = [self]
-    if hasattr(self, '_layers'):
-      child_layers = trackable_layer_utils.filter_empty_layer_containers(
-          self._layers)
-      for child_layer in child_layers:
-        all_layers.extend(child_layer._gather_layers())
-    return all_layers
-
-  @property
-  @tracking.cached_per_instance
-  def _attribute_sentinel(self):
-    return trackable_layer_utils.AttributeSentinel()
+  def _flatten_layers(self, recursive=True, include_self=True):
+    if include_self:
+      yield self
+
+    # Only instantiate set and deque if needed.
+    layers_or_containers = getattr(self, '_layers', None)
+    if layers_or_containers:
+      seen_object_ids = set()
+      deque = collections.deque(layers_or_containers)
+      while deque:
+        layer_or_container = deque.popleft()
+
+        layer_or_container_id = id(layer_or_container)
+        if layer_or_container_id in seen_object_ids:
+          continue
+        seen_object_ids.add(layer_or_container_id)
+
+        if isinstance(layer_or_container, Layer):
+          yield layer_or_container
+          # Introspect recursively through sublayers.
+          if recursive:
+            sublayers = getattr(layer_or_container, '_layers', None)
+            if sublayers:
+              deque.extendleft(reversed(sublayers))
+        elif isinstance(layer_or_container,
+                        data_structures.TrackableDataStructure):
+          # Data structures are introspected even with `recursive=False`.
+          tracked_values = layer_or_container._values
+          if tracked_values:
+            deque.extendleft(reversed(tracked_values))
 
   # This is a hack so that the is_layer (within
   # training/trackable/layer_utils.py) check doesn't get the weights attr.
@@ -829,22 +829,15 @@ class Layer(base_layer.Layer):
     return self._name
 
   @property
-  @trackable_layer_utils.cache_recursive_attribute('dynamic')
   def dynamic(self):
-    # NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
-    # then this cache logic must be updated.
-    return self._dynamic or any(layer.dynamic
-                                for layer in self._unique_sublayers())
+    return any(layer._dynamic for layer in self._flatten_layers())
 
   @property
   @doc_controls.do_not_generate_docs
-  @trackable_layer_utils.cache_recursive_attribute('stateful')
   def stateful(self):
-    return self._stateful or any(
-        getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
+    return any(layer._stateful for layer in self._flatten_layers())
 
   @stateful.setter
-  @trackable_layer_utils.invalidate_recursive_cache('stateful')
   def stateful(self, value):
     self._stateful = value
 
@@ -916,7 +909,7 @@ class Layer(base_layer.Layer):
   @property
   def updates(self):
     collected_updates = []
-    all_layers = self._gather_unique_layers()
+    all_layers = self._flatten_layers()
     with backend.get_graph().as_default():
       for layer in all_layers:
         if not layer.trainable and not layer.stateful:
@@ -945,7 +938,7 @@ class Layer(base_layer.Layer):
       A list of tensors.
     """
     collected_losses = []
-    all_layers = self._gather_unique_layers()
+    all_layers = self._flatten_layers()
     for layer in all_layers:
       # If any eager losses are present, we assume the model to be part of an
       # eager training loop (either a custom one or the one used when
@@ -1075,8 +1068,7 @@ class Layer(base_layer.Layer):
   @property
   def metrics(self):
     collected_metrics = []
-    all_layers = self._gather_unique_layers()
-    for layer in all_layers:
+    for layer in self._flatten_layers():
       collected_metrics.extend(layer._metrics)
     return collected_metrics
 
@@ -2187,7 +2179,6 @@ class Layer(base_layer.Layer):
       super(tracking.AutoTrackable, self).__setattr__(
           '_layers',
          [l for l in self._layers if l is not existing_value])
-      self._attribute_sentinel.invalidate_all()
     if isinstance(existing_value, tf_variables.Variable):
       super(tracking.AutoTrackable, self).__setattr__(
           '_trainable_weights',
@@ -2196,13 +2187,6 @@ class Layer(base_layer.Layer):
           '_non_trainable_weights',
           [w for w in self._non_trainable_weights if w is not existing_value])
 
-    # Any time we change `_layers` (either by deleting the attribute or by
-    # reassigning it which will call __delattr__ from __setattr__) the topology
-    # of the subgraph of Layers may change. In that case we will need to
-    # recompute any attribute which depends on that subgraph.
-    if name == '_layers':
-      self._attribute_sentinel.invalidate_all()
-
   def __setattr__(self, name, value):
     if (name == '_self_setattr_tracking' or
         not getattr(self, '_self_setattr_tracking', True) or
@@ -2247,8 +2231,6 @@ class Layer(base_layer.Layer):
         # container types which compare equal.
         if not any((layer is value for layer in self._layers)):
           self._layers.append(value)
-          if hasattr(value, '_attribute_sentinel'):
-            value._attribute_sentinel.add_parent(self._attribute_sentinel)
       if hasattr(value, '_use_resource_variables'):
         # Legacy layers (V1 tf.layers) must always use
         # resource variables.
@@ -2296,35 +2278,6 @@ class Layer(base_layer.Layer):
             getattr(layer, attribute) for layer in nested_layers))
     return []
 
-  def _gather_unique_layers(self):
-    """Returns the current layer and all its children depth first deduped.
-
-    We are deduping after getting the layers to maintain the order.
-    """
-    all_layers = self._gather_layers()
-    unique_layers, seen_layers = [], object_identity.ObjectIdentitySet()
-    for layer in all_layers:
-      if layer not in seen_layers:
-        unique_layers.append(layer)
-        # Track the Variable's identity to avoid __eq__ issues.
-        seen_layers.add(layer)
-    return unique_layers
-
-  def _gather_layers(self):
-    """Returns the current layer and all its children depth first."""
-    all_layers = [self]
-    if hasattr(self, '_layers'):
-      child_layers = trackable_layer_utils.filter_empty_layer_containers(
-          self._layers)
-      for child_layer in child_layers:
-        all_layers.extend(child_layer._gather_layers())
-    return all_layers
-
-  @property
-  @tracking.cached_per_instance
-  def _attribute_sentinel(self):
-    return trackable_layer_utils.AttributeSentinel()
-
   # This is a hack so that the is_layer (within
   # training/trackable/layer_utils.py) check doesn't get the weights attr.
   # TODO(b/110718070): Remove when fixed.
@@ -193,7 +193,6 @@ class Functional(training_lib.Model):
     self._layer_call_argspecs = {}
     for layer in self._layers:
       self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
-      layer._attribute_sentinel.add_parent(self._attribute_sentinel)
 
     # Build self.input_names and self.output_names.
     self._set_output_names()
@@ -731,10 +730,6 @@ class Functional(training_lib.Model):
         self._layers.append(layer)
         deferred_layers.append(layer)
         self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
-
-        # This allows the added layer to broadcast mutations to the current
-        # layer, which is necessary to ensure cache correctness.
-        layer._attribute_sentinel.add_parent(self._attribute_sentinel)
       layer_set.add(layer)
     self._handle_deferred_layer_dependencies(deferred_layers)
 
@@ -188,10 +188,6 @@ class Sequential(functional.Functional):
             ' of a layer in this model. Update the `name` argument '
             'to pass a unique name.' % (layer.name,))
 
-      # This allows the added layer to broadcast mutations to the current
-      # layer, which is necessary to ensure cache correctness.
-      layer._attribute_sentinel.add_parent(self._attribute_sentinel)
-
       self.built = False
       set_inputs = False
       if not self._layers:
@@ -236,9 +232,6 @@ class Sequential(functional.Functional):
       self._handle_deferred_layer_dependencies([layer])
 
     self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
-    # Different Model types add to `._layers` in different ways, so for safety
-    # we do a cache invalidation to make sure the changes are reflected.
-    self._attribute_sentinel.invalidate_all()
 
   @trackable.no_automatic_dependency_tracking
   def pop(self):
@@ -252,7 +245,6 @@ class Sequential(functional.Functional):
 
     layer = self._layers.pop()
     self._layer_call_argspecs.pop(layer)
-    self._attribute_sentinel.invalidate_all()
     if not self.layers:
       self.outputs = None
       self.inputs = None
@@ -628,8 +628,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     if self.compiled_metrics is not None:
       metrics += self.compiled_metrics.metrics
 
-    all_layers = self._gather_unique_layers()
-    for l in all_layers:
+    for l in self._flatten_layers():
       metrics.extend(l._metrics)  # pylint: disable=protected-access
     return metrics
 
@@ -2310,7 +2309,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
 
   @property
   def layers(self):
-    return self._unique_sublayers()
+    return list(self._flatten_layers(include_self=False, recursive=False))
 
   def get_layer(self, name=None, index=None):
     """Retrieves a layer based on either its name (unique) or index.
@@ -1499,6 +1499,65 @@ class TrainingTest(keras_parameterized.TestCase):
     new_kernel = model.get_weights()[1]
     self.assertNotAllEqual(old_kernel, new_kernel)
 
+  @keras_parameterized.run_all_keras_modes
+  def test_layer_ordering(self):
+
+    class MyLayer(layers_module.Layer):
+      pass
+
+    class MyModel(training_module.Model):
+
+      def __init__(self, name):
+        super(MyModel, self).__init__(name=name)
+
+        self.weight = variables_lib.Variable(0, name=name)
+
+        self.direct_sublayer = MyLayer(name='direct')
+        self.direct_sublayer.d = {'d': MyLayer(name='direct/dict')}
+
+        self.dict_sublayer = {'d': MyLayer(name='dict')}
+        self.dict_sublayer['d'].direct = MyLayer(name='dict/direct')
+
+    model = MyModel('model')
+    # All sublayers, including self and recursive sublayers.
+    self.assertEqual(['model', 'direct', 'direct/dict', 'dict', 'dict/direct'],
+                     [l.name for l in model._flatten_layers()])
+    # Only direct sublayers, including those in data structures.
+    self.assertEqual(['direct', 'dict'], [l.name for l in model.layers])
+
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_trainable_state_setting(self):
+
+    class UpdateLayer(layers_module.Layer):
+
+      def __init__(self):
+        super(UpdateLayer, self).__init__()
+        self.v = variables_lib.Variable(0., trainable=False)
+
+      def call(self, x):
+        self.add_update(lambda: self.v.assign_add(1.))
+        return x * self.v
+
+    layer = UpdateLayer()
+    model_with_updates = sequential.Sequential([layer])
+    model_with_updates.compile(
+        'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+
+    layer.trainable = False
+    model_without_updates = sequential.Sequential([layer])
+    model_without_updates.compile(
+        'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+
+    x, y = np.ones((10, 1)), np.ones((10, 1))
+
+    self.assertEqual(self.evaluate(layer.v), 0.)
+    model_with_updates.fit(x, y, batch_size=10)
+    # assign_add called.
+    self.assertEqual(self.evaluate(layer.v), 1.)
+    model_without_updates.fit(x, y, batch_size=10)
+    # assign_add not called.
+    self.assertEqual(self.evaluate(layer.v), 1.)
+
 
 class TestExceptionsAndWarnings(keras_parameterized.TestCase):
 
@@ -505,11 +505,6 @@ def _in_place_subclassed_model_reset(model):
       setattr(model, name, fresh_layer)
       model._layers.append(fresh_layer)
 
-  # The base Layer __setattr__ will invalidate its attribute cache when
-  # `._layers` is assigned, but it has no way to know when the underlying list
-  # is mutated so we must explicitly signal the append.
-  model._attribute_sentinel.invalidate_all()
-
   # Cache original model build attributes (in addition to layers)
   if (not hasattr(model, '_original_attributes_cache') or
       model._original_attributes_cache is None):
@@ -79,7 +79,7 @@ def use_wrapped_call(layer, call_fn, default_training_value=None,
   # child layers. This causes `.losses` to only return eager losses.
   # pylint: disable=protected-access
   if context.executing_eagerly():
-    for i in layer._gather_unique_layers():
+    for i in layer._flatten_layers():
       if i is not layer:
         i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
   # pylint: enable=protected-access