Support TF Modules inside Keras Layers and Models.

With this change, it is now possible to mix and match tf.keras.Layers and
tf.Modules inside a tf.keras.Model, and everything will be tracked properly.

- Variables in tf.Modules that are set as attributes of custom Layers and
  Models now show up properly in properties such as Layer.trainable_variables
  and Model.trainable_variables.
- tf.Modules do not show up in Model.layers. Instead, a new method
  Layer._flatten_modules is added that iterates over tf.Modules and Layers in
  the order that Keras expects. The existing method Layer.submodules (inherited
  from tf.Module) can still be used to iterate over tf.Modules and Layers with
  the tf.Module ordering. Layer._flatten_layers is built on top of
  Layer._flatten_modules (see the sketch after this list).
- Layer._layers is renamed to Layer._self_tracked_trackables to avoid naming
  conflicts with user-defined attributes (and to reflect that this attr
  contains Layers, Modules, and TrackableDataStructures)
- A new property, tf.Module.non_trainable_variables, is added to enable this.
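
A rough sketch of the iteration helpers described above. Note that
_flatten_modules and _flatten_layers are private helpers; this sketch assumes
the signatures shown in the diffs below, and the class names are illustrative.

import tensorflow as tf

class Wrapper(tf.keras.layers.Layer):

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(1)  # a Keras Layer
    self.mod = tf.Module()                 # a plain tf.Module

w = Wrapper()
# _flatten_modules yields tracked Modules and Layers in the order Keras expects.
assert w.mod in list(w._flatten_modules(include_self=False))
# _flatten_layers keeps only Layers, so the plain Module is filtered out.
assert w.mod not in list(w._flatten_layers(include_self=False))
# submodules (inherited from tf.Module) still sees both, in tf.Module order.
assert w.mod in w.submodules and w.dense in w.submodules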

PiperOrigin-RevId: 339917644
Change-Id: I96a7302745280a6261de8c4295c5cbf5f4d7dd5c
Thomas O'Malley 2020-10-30 12:22:31 -07:00 committed by TensorFlower Gardener
parent 9c3c3855bd
commit d266494953
61 changed files with 468 additions and 174 deletions

View File

@ -386,10 +386,14 @@ This release contains contributions from many people at Google, as well as:
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
the values of these symbols at an iteration do not depend on the
previous iteration. These types of loops must run at least one
iteration, and will raise a runtime error otherwise.
* Variables contained in `tf.Module`s that are set as attributes of
custom Keras `Layer`s and `Model`s are now tracked in
the properties `layer.trainable_variables` and
`layer.non_trainable_variables`.
Example:
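
The snippet that followed "Example:" was lost in extraction; the following is
a minimal reconstruction consistent with the tests added in this commit (the
exact published example may differ):

import tensorflow as tf

class MyModule(tf.Module):

  def __init__(self):
    super().__init__()
    self.v = tf.Variable(1., trainable=True)
    self.w = tf.Variable(2., trainable=False)

class MyLayer(tf.keras.layers.Layer):

  def __init__(self):
    super().__init__()
    self.mod = MyModule()  # plain tf.Module assigned as an attribute

layer = MyLayer()
assert len(layer.trainable_variables) == 1      # picks up MyModule.v
assert len(layer.non_trainable_variables) == 1  # picks up MyModule.w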

View File

@ -397,10 +397,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._autocast = kwargs.get('autocast',
base_layer_utils.v2_dtype_behavior_enabled())
# Dependencies tracked via attribute assignment.
# All layers in order of horizontal graph traversal.
# Entries are unique. For models includes input and output layers.
self._maybe_create_attribute('_layers', [])
# Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
# Ordered by when the object was assigned as an attr.
# Entries are unique.
self._maybe_create_attribute('_self_tracked_trackables', [])
# These lists will be filled via successive calls
# to self._add_inbound_node().
@ -1351,14 +1351,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Trainable weights are updated via gradient descent during training.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of trainable variables.
"""
if self.trainable:
children_weights = self._gather_children_attribute('trainable_weights')
children_weights = self._gather_children_attribute('trainable_variables')
return self._dedup_weights(self._trainable_weights + children_weights)
else:
return []
@ -1370,18 +1367,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Non-trainable weights are *not* updated during training. They are expected
to be updated manually in `call()`.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of non-trainable variables.
"""
if self.trainable:
children_weights = self._gather_children_attribute(
'non_trainable_weights')
'non_trainable_variables')
non_trainable_weights = self._non_trainable_weights + children_weights
else:
children_weights = self._gather_children_attribute('weights')
children_weights = self._gather_children_attribute('variables')
non_trainable_weights = (
self._trainable_weights + self._non_trainable_weights +
children_weights)
@ -1391,9 +1385,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def weights(self):
"""Returns the list of all layer variables/weights.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of variables.
"""
@ -1609,7 +1600,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def _clear_losses(self):
"""Used every step in eager to reset losses."""
# Set to thread local directly to avoid Layer.__setattr__ overhead.
if not getattr(self, '_layers', None): # Fast path for single Layer.
if not getattr(self, '_self_tracked_trackables',
None): # Fast path for single Layer.
self._thread_local._eager_losses = []
else:
for layer in self._flatten_layers():
@ -2779,7 +2771,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
default_value: Object, the default value of the attribute.
"""
if not hasattr(self, name):
super(Layer, self).__setattr__(name, default_value)
self.__setattr__(name, default_value)
def __delattr__(self, name):
# For any super.__delattr__() call, we will directly use the implementation
@ -2813,8 +2805,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if (isinstance(existing_value, Layer)
or base_layer_utils.has_weights(existing_value)):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
'_self_tracked_trackables',
[l for l in self._self_tracked_trackables if l is not existing_value])
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@ -2837,7 +2829,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
'different name.').format(name))
return
# Keep track of trackable objects, for the needs of `Network.save_weights`.
# Wraps data structures in `Trackable`, unwraps `NoDependency` objects.
value = data_structures.sticky_attribute_assignment(
trackable=self, value=value, name=name)
@ -2856,16 +2848,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
self._metrics.append(val)
# TODO(scottzhu): Need to track Module object as well for weight tracking.
# Be careful about metric if it becomes a Module in future.
# Append value to self._layers if relevant
# Append value to self._self_tracked_trackables if relevant
if (getattr(self, '_auto_track_sub_layers', True) and
(isinstance(value, Layer) or base_layer_utils.has_weights(value))):
self._maybe_create_attribute('_layers', [])
(isinstance(value, module.Module) or
base_layer_utils.has_weights(value))):
self._maybe_create_attribute('_self_tracked_trackables', [])
# We need to check object identity to avoid de-duplicating empty
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if not any((layer is value for layer in self._self_tracked_trackables)):
self._self_tracked_trackables.append(value)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@ -2903,45 +2894,59 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def _gather_children_attribute(self, attribute):
assert attribute in {
'weights', 'trainable_weights', 'non_trainable_weights'
'variables', 'trainable_variables', 'non_trainable_variables'
}
if hasattr(self, '_layers'):
nested_layers = layer_utils.filter_empty_layer_containers(
self._layers)
if hasattr(self, '_self_tracked_trackables'):
nested_layers = self._flatten_modules(include_self=False, recursive=False)
return list(
itertools.chain.from_iterable(
getattr(layer, attribute) for layer in nested_layers))
return []
def _flatten_layers(self, recursive=True, include_self=True):
for m in self._flatten_modules(
recursive=recursive, include_self=include_self):
if isinstance(m, Layer):
yield m
def _flatten_modules(self, recursive=True, include_self=True):
"""Flattens `tf.Module` instances (excluding `Metrics`).
Arguments:
recursive: Whether to recursively flatten through submodules.
include_self: Whether to include this `Layer` instance.
Yields:
`tf.Module` instance tracked by this `Layer`.
"""
if include_self:
yield self
# Only instantiate set and deque if needed.
layers_or_containers = getattr(self, '_layers', None)
if layers_or_containers:
trackables = getattr(self, '_self_tracked_trackables', None)
if trackables:
seen_object_ids = set()
deque = collections.deque(layers_or_containers)
deque = collections.deque(trackables)
while deque:
layer_or_container = deque.popleft()
layer_or_container_id = id(layer_or_container)
if layer_or_container_id in seen_object_ids:
trackable_obj = deque.popleft()
trackable_id = id(trackable_obj)
if trackable_id in seen_object_ids:
continue
seen_object_ids.add(layer_or_container_id)
seen_object_ids.add(trackable_id)
if (isinstance(layer_or_container, Layer) and
not isinstance(layer_or_container, metrics_mod.Metric)):
yield layer_or_container
# Metrics are not considered part of the Layer's topology.
if (isinstance(trackable_obj, module.Module) and
not isinstance(trackable_obj, metrics_mod.Metric)):
yield trackable_obj
# Introspect recursively through sublayers.
if recursive:
sublayers = getattr(layer_or_container, '_layers', None)
if sublayers:
deque.extendleft(reversed(sublayers))
elif isinstance(layer_or_container,
data_structures.TrackableDataStructure):
subtrackables = getattr(trackable_obj, '_self_tracked_trackables',
None)
if subtrackables:
deque.extendleft(reversed(subtrackables))
elif isinstance(trackable_obj, data_structures.TrackableDataStructure):
# Data structures are introspected even with `recursive=False`.
tracked_values = layer_or_container._values
tracked_values = trackable_obj._values
if tracked_values:
deque.extendleft(reversed(tracked_values))

View File

@ -48,6 +48,7 @@ from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.legacy_tf_layers import core as legacy_core
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@ -840,6 +841,58 @@ class BaseLayerTest(keras_parameterized.TestCase):
layer = MyLayer(activity_regularizer='l2')
self.assertIsInstance(layer.activity_regularizer, regularizers.L2)
def test_tf_module_tracking(self):
class MyModule(module.Module):
def __init__(self):
super(MyModule, self).__init__()
self.v1 = variables.Variable(1., trainable=True, name='v1')
self.v2 = variables.Variable(2., trainable=False, name='v2')
def __call__(self, x):
return x * self.v1 * self.v2
class MyLayer(base_layer.Layer):
def __init__(self, **kwargs):
super(MyLayer, self).__init__(**kwargs)
self.my_modules = {}
self.my_modules['a'] = MyModule()
def call(self, x):
return self.my_modules['a'](x)
layer = MyLayer()
self.assertLen(layer.variables, 2)
self.assertLen(layer.trainable_variables, 1)
self.assertLen(layer.non_trainable_variables, 1)
layer.trainable = False
self.assertLen(layer.variables, 2)
self.assertLen(layer.trainable_variables, 0)
self.assertLen(layer.non_trainable_variables, 2)
class MyModel(training_lib.Model):
def __init__(self):
super(MyModel, self).__init__()
self.my_modules = []
self.my_modules.append(MyModule())
def call(self, x):
return self.my_modules[0](x)
model = MyModel()
self.assertLen(model.variables, 2)
self.assertLen(model.trainable_variables, 1)
self.assertLen(model.non_trainable_variables, 1)
model.trainable = False
self.assertLen(model.variables, 2)
self.assertLen(model.trainable_variables, 0)
self.assertLen(model.non_trainable_variables, 2)
class SymbolicSupportTest(keras_parameterized.TestCase):
@ -1078,12 +1131,13 @@ class NestedTrackingTest(test.TestCase):
del l.c
l.d = last_assignment
del l.d
self.assertEqual([last_assignment], l._layers)
sublayers = list(l._flatten_layers(include_self=False, recursive=False))
self.assertEqual([last_assignment], sublayers)
self.assertEqual([], l.trainable_weights)
self.assertEqual([], l.non_trainable_weights)
self.assertEqual([], l.weights)
del l.a
self.assertEqual([], l._layers)
self.assertEqual([], l._self_tracked_trackables)
def test_assign_op_not_tracked_as_variable(self):

View File

@ -212,7 +212,7 @@ class Layer(base_layer.Layer):
# Dependencies tracked via attribute assignment.
# All layers in order of horizontal graph traversal.
# Entries are unique. For models includes input and output layers.
self._maybe_create_attribute('_layers', [])
self._maybe_create_attribute('_self_tracked_trackables', [])
# These lists will be filled via successive calls
# to self._add_inbound_node().
@ -881,7 +881,7 @@ class Layer(base_layer.Layer):
@trainable.setter
def trainable(self, value):
self._trainable = value
for layer in getattr(self, '_layers', []):
for layer in getattr(self, '_self_tracked_trackables', []):
layer.trainable = value
@property
@ -909,36 +909,6 @@ class Layer(base_layer.Layer):
'Got: {}'.format(v))
self._input_spec = value
@property
def trainable_weights(self):
if self.trainable:
children_weights = self._gather_children_attribute('trainable_weights')
return self._dedup_weights(self._trainable_weights + children_weights)
else:
return []
@property
def non_trainable_weights(self):
if self.trainable:
children_weights = self._gather_children_attribute(
'non_trainable_weights')
non_trainable_weights = self._non_trainable_weights + children_weights
else:
children_weights = self._gather_children_attribute('weights')
non_trainable_weights = (
self._trainable_weights + self._non_trainable_weights +
children_weights)
return self._dedup_weights(non_trainable_weights)
@property
def weights(self):
"""Returns the list of all layer variables/weights.
Returns:
A list of variables.
"""
return self.trainable_weights + self.non_trainable_weights
@property
def updates(self):
collected_updates = []
@ -2137,21 +2107,20 @@ class Layer(base_layer.Layer):
Returns:
A dict mapping all sublayers to their `trainable` value.
"""
layers = layer_utils.filter_empty_layer_containers(self._layers)
# Keep track of each top-level layer's `trainable` as well as the
# state of all of its sublayers.
layers = self._flatten_layers(include_self=False, recursive=False)
trainable_state = {self: self.trainable}
for layer in layers:
trainable_state.update(layer._get_trainable_state())
for l in layers:
trainable_state.update(l._get_trainable_state())
return trainable_state
def _set_trainable_state(self, trainable_state):
"""Set `trainable` state for each sublayer."""
layers = layer_utils.filter_empty_layer_containers(self._layers)
if self in trainable_state:
self.trainable = trainable_state[self]
for layer in layers:
layer._set_trainable_state(trainable_state)
layers = self._flatten_layers(include_self=False, recursive=False)
for l in layers:
if l in trainable_state:
l._set_trainable_state(trainable_state)
@property
def _obj_reference_counts(self):
@ -2175,7 +2144,7 @@ class Layer(base_layer.Layer):
default_value: Object, the default value of the attribute.
"""
if not hasattr(self, name):
super(Layer, self).__setattr__(name, default_value)
self.__setattr__(name, default_value)
def __delattr__(self, name):
# For any super.__delattr__() call, we will directly use the implementation
@ -2209,8 +2178,8 @@ class Layer(base_layer.Layer):
if (isinstance(existing_value, Layer)
or base_layer_utils.has_weights(existing_value)):
super(tracking.AutoTrackable, self).__setattr__(
'_layers',
[l for l in self._layers if l is not existing_value])
'_self_tracked_trackables',
[l for l in self._self_tracked_trackables if l is not existing_value])
if isinstance(existing_value, tf_variables.Variable):
super(tracking.AutoTrackable, self).__setattr__(
'_trainable_weights',
@ -2258,11 +2227,11 @@ class Layer(base_layer.Layer):
# Append value to self._layers if relevant
if (getattr(self, '_auto_track_sub_layers', True) and
(isinstance(value, Layer) or base_layer_utils.has_weights(value))):
self._maybe_create_attribute('_layers', [])
self._maybe_create_attribute('_self_tracked_trackables', [])
# We need to check object identity to avoid de-duplicating empty
# container types which compare equal.
if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if not any((layer is value for layer in self._self_tracked_trackables)):
self._self_tracked_trackables.append(value)
if hasattr(value, '_use_resource_variables'):
# Legacy layers (V1 tf.layers) must always use
# resource variables.
@ -2298,18 +2267,6 @@ class Layer(base_layer.Layer):
# at __delattr__.
super(tracking.AutoTrackable, self).__setattr__(name, value)
def _gather_children_attribute(self, attribute):
assert attribute in {
'weights', 'trainable_weights', 'non_trainable_weights'
}
if hasattr(self, '_layers'):
nested_layers = layer_utils.filter_empty_layer_containers(
self._layers)
return list(
itertools.chain.from_iterable(
getattr(layer, attribute) for layer in nested_layers))
return []
# This is a hack so that the is_layer (within
# training/trackable/layer_utils.py) check doesn't get the weights attr.
# TODO(b/110718070): Remove when fixed.

View File

@ -129,7 +129,10 @@ class TestDeferredSequential(keras_parameterized.TestCase):
path = os.path.join(self.get_temp_dir(), 'model_path')
model.save(path)
new_model = keras.models.load_model(path)
for layer1, layer2 in zip(model._layers, new_model._layers):
model_layers = model._flatten_layers(include_self=True, recursive=False)
new_model_layers = new_model._flatten_layers(
include_self=True, recursive=False)
for layer1, layer2 in zip(model_layers, new_model_layers):
self.assertEqual(layer1.name, layer2.name)
for w1, w2 in zip(layer1.weights, layer2.weights):
self.assertAllClose(w1, w2)
@ -144,7 +147,10 @@ class TestDeferredSequential(keras_parameterized.TestCase):
path = os.path.join(self.get_temp_dir(), 'model_path.h5')
model.save(path)
new_model = keras.models.load_model(path)
for layer1, layer2 in zip(model._layers, new_model._layers):
model_layers = model._flatten_layers(include_self=True, recursive=False)
new_model_layers = new_model._flatten_layers(
include_self=True, recursive=False)
for layer1, layer2 in zip(model_layers, new_model_layers):
self.assertEqual(layer1.name, layer2.name)
for w1, w2 in zip(layer1.weights, layer2.weights):
self.assertAllClose(w1, w2)

View File

@ -204,9 +204,9 @@ class Functional(training_lib.Model):
self.inputs, self.outputs)
self._network_nodes = nodes
self._nodes_by_depth = nodes_by_depth
self._layers = layers
self._self_tracked_trackables = layers
self._layer_call_argspecs = {}
for layer in self._layers:
for layer in self._self_tracked_trackables:
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# Build self.input_names and self.output_names.
@ -797,11 +797,11 @@ class Functional(training_lib.Model):
self._nodes_by_depth[depth].append(node)
# Insert layers and update other layer attrs.
layer_set = set(self._layers)
layer_set = set(self._self_tracked_trackables)
deferred_layers = []
for layer in layers:
if layer not in layer_set:
self._layers.append(layer)
self._self_tracked_trackables.append(layer)
deferred_layers.append(layer)
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
layer_set.add(layer)
@ -1089,7 +1089,8 @@ def _should_skip_first_node(layer):
# the network config.
return (isinstance(layer, Functional) and
# Filter out Sequential models without an input shape.
isinstance(layer._layers[0], input_layer_module.InputLayer))
isinstance(layer._self_tracked_trackables[0],
input_layer_module.InputLayer))
def connect_ancillary_layers(model, created_layers):

View File

@ -192,7 +192,8 @@ class Sequential(functional.Functional):
self.built = False
set_inputs = False
if not self._layers:
self._maybe_create_attribute('_self_tracked_trackables', [])
if not self._self_tracked_trackables:
if isinstance(layer, input_layer.InputLayer):
# Case where the user passes an Input or InputLayer layer via `add`.
set_inputs = True
@ -230,7 +231,7 @@ class Sequential(functional.Functional):
self._init_graph_network(self.inputs, self.outputs)
self._graph_initialized = True
else:
self._layers.append(layer)
self._self_tracked_trackables.append(layer)
self._handle_deferred_layer_dependencies([layer])
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
@ -245,7 +246,7 @@ class Sequential(functional.Functional):
if not self.layers:
raise TypeError('There are no layers in the model.')
layer = self._layers.pop()
layer = self._self_tracked_trackables.pop()
self._layer_call_argspecs.pop(layer)
if not self.layers:
self.outputs = None
@ -466,8 +467,8 @@ class Sequential(functional.Functional):
layer_configs = []
for layer in super(Sequential, self).layers:
# `super().layers` include the InputLayer if available (it is filtered out
# of `self.layers`). Note that `self._layers` is managed by the
# tracking infrastructure and should not be used.
# of `self.layers`). Note that `self._self_tracked_trackables` is managed
# by the tracking infrastructure and should not be used.
layer_configs.append(generic_utils.serialize_keras_object(layer))
config = {
'name': self.name,

View File

@ -412,20 +412,30 @@ class TestSequential(keras_parameterized.TestCase):
"""Test that Sequential only tracks layers added in init or `.add`."""
layer = keras.layers.Dense(1)
model = keras.Sequential([layer])
self.assertEqual(model._layers[-1], layer)
self.assertEqual(
list(model._flatten_layers(include_self=False, recursive=False))[-1],
layer)
model.a = [keras.layers.Dense(3)] # should not be added to the layers list.
self.assertEqual(model._layers[-1], layer)
self.assertEqual(
list(model._flatten_layers(include_self=False, recursive=False))[-1],
layer)
layer2 = keras.layers.Dense(2)
model.add(layer2)
self.assertEqual(model._layers[-1], layer2)
self.assertEqual(
list(model._flatten_layers(include_self=False, recursive=False))[-1],
layer2)
model.a = [keras.layers.Dense(3)] # should not be added to the layers list.
self.assertEqual(model._layers[-1], layer2)
self.assertEqual(
list(model._flatten_layers(include_self=False, recursive=False))[-1],
layer2)
model.pop()
self.assertEqual(model._layers[-1], layer)
self.assertEqual(
list(model._flatten_layers(include_self=False, recursive=False))[-1],
layer)
def test_config_preserves_input_layer(self):
model = keras.Sequential([
@ -436,8 +446,10 @@ class TestSequential(keras_parameterized.TestCase):
config = model.get_config()
new_model = keras.Sequential.from_config(config)
self.assertTrue(new_model.built)
self.assertEqual(new_model._layers[0].dtype, 'int32')
self.assertEqual(new_model._layers[0].name, 'my_embedding_input')
layers = list(
new_model._flatten_layers(include_self=False, recursive=False))
self.assertEqual(layers[0].dtype, 'int32')
self.assertEqual(layers[0].name, 'my_embedding_input')
def test_name_unicity(self):
model = keras.Sequential()

View File

@ -78,7 +78,6 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@ -1922,21 +1921,35 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
@property
def trainable_weights(self):
self._assert_weights_created()
return self._dedup_weights(
trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights))
if not self._trainable:
return []
trainable_variables = []
for trackable_obj in self._self_tracked_trackables:
trainable_variables += trackable_obj.trainable_variables
trainable_variables += self._trainable_weights
return self._dedup_weights(trainable_variables)
@property
def non_trainable_weights(self):
self._assert_weights_created()
return self._dedup_weights(
trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights +
self._trainable_weights))
non_trainable_variables = []
for trackable_obj in self._self_tracked_trackables:
non_trainable_variables += trackable_obj.non_trainable_variables
if not self._trainable:
# Return order is all trainable vars, then all non-trainable vars.
trainable_variables = []
for trackable_obj in self._self_tracked_trackables:
trainable_variables += trackable_obj.trainable_variables
non_trainable_variables = (
trainable_variables + self._trainable_weights +
non_trainable_variables + self._non_trainable_weights)
else:
non_trainable_variables = (
non_trainable_variables + self._non_trainable_weights)
return self._dedup_weights(non_trainable_variables)
def get_weights(self):
"""Retrieves the weights of the model.
@ -2349,8 +2362,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
"""Returns the undeduplicated list of all layer variables/weights."""
self._assert_weights_created()
weights = []
for layer in self._layers:
weights += layer.weights
for layer in self._self_tracked_trackables:
weights += layer.variables
weights += (self._trainable_weights + self._non_trainable_weights)
return weights
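
The reordering above is visible through the public API; a minimal sketch,
consistent with the tests in this commit (class and variable names are
illustrative):

import tensorflow as tf

mod = tf.Module()
mod.v1 = tf.Variable(1., trainable=True)
mod.v2 = tf.Variable(2., trainable=False)

class MyModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self.mod = mod  # plain tf.Module tracked by the Model

  def call(self, x):
    return x * self.mod.v1 * self.mod.v2

model = MyModel()
assert len(model.trainable_weights) == 1
assert len(model.non_trainable_weights) == 1

model.trainable = False
assert model.trainable_weights == []
# Return order is all (formerly) trainable vars, then all non-trainable vars.
assert len(model.non_trainable_weights) == 2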

View File

@ -846,7 +846,8 @@ class TrainingTest(keras_parameterized.TestCase):
return self.layer2(self.layer1(inputs))
l = LayerWithWeightSharedLayers()
self.assertEqual(l._layers, [l.layer1, l.layer2])
layers = list(l._flatten_layers(include_self=False, recursive=False))
self.assertEqual(layers, [l.layer1, l.layer2])
self.assertEqual(l.variables,
[l.layer1.trainable_var, l.layer1.non_trainable_var])
self.assertEqual(l.trainable_variables, [l.layer1.trainable_var])

View File

@ -504,7 +504,9 @@ class Model(training_lib.Model):
return super(Model, self).metrics
metrics += self._compile_metric_functions
metrics.extend(self._metrics)
metrics.extend(_get_metrics_from_layers(self._layers))
metrics.extend(
_get_metrics_from_layers(
list(self._flatten_layers(include_self=False, recursive=False))))
return metrics
@property
@ -1717,7 +1719,7 @@ class Model(training_lib.Model):
# Avoids the override in Sequential.layers which filters Input layers.
# (Which are often the very layers that we're after.)
layers = layer_utils.filter_empty_layer_containers(self._layers)
layers = self._flatten_layers(include_self=False, recursive=False)
first_layer = next(layers, None)
if first_layer:
# The per-replica static batch size.

View File

@ -318,10 +318,10 @@ def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
layers = [] # Layers needed to compute the model's outputs.
layer_map = {}
# Use model._layers to ensure that all layers are cloned. The model's layers
# Ensure that all layers are cloned. The model's layers
# property will exclude the initial InputLayer (if it exists) in the model,
# resulting in a different Sequential model structure.
for layer in model._layers:
for layer in model._flatten_layers(include_self=False, recursive=False):
if isinstance(layer, InputLayer) and input_tensors is not None:
# If input tensors are provided, the original model's InputLayer is
# overwritten with a different InputLayer.
@ -460,9 +460,8 @@ def _in_place_subclassed_model_reset(model):
# Retrieve all layers tracked by the model as well as their attribute names
attributes_cache = {}
for name in dir(model):
# Skip the check of methods in tf.Module since they basically
# recursively query all the other attributes within same module.
if name == 'submodules':
# Skip attrs that track other trackables.
if name == 'submodules' or name == '_self_tracked_trackables':
continue
try:
@ -489,10 +488,11 @@ def _in_place_subclassed_model_reset(model):
# Replace layers on the model with fresh layers
layers_to_names = {value: key for key, value in attributes_cache.items()}
original_layers = model._layers[:]
original_layers = list(
model._flatten_layers(include_self=False, recursive=False))
setattr_tracking = model._setattr_tracking
model._setattr_tracking = False
model._layers = []
model._self_tracked_trackables = []
for layer in original_layers: # We preserve layer order.
config = layer.get_config()
# This will not work for nested subclassed models used as layers.
@ -505,7 +505,7 @@ def _in_place_subclassed_model_reset(model):
fresh_layer = layer.__class__.from_config(config)
name = layers_to_names[layer]
setattr(model, name, fresh_layer)
model._layers.append(fresh_layer)
model._self_tracked_trackables.append(fresh_layer)
# Cache original model build attributes (in addition to layers)
if (not hasattr(model, '_original_attributes_cache') or
@ -576,11 +576,11 @@ def in_place_subclassed_model_state_restoration(model):
# when they're constructed.
setattr_tracking = model._setattr_tracking
model._setattr_tracking = False
model._layers = []
model._self_tracked_trackables = []
for name, value in model._original_attributes_cache.items():
setattr(model, name, value)
if isinstance(value, Layer):
model._layers.append(value)
model._self_tracked_trackables.append(value)
model._original_attributes_cache = None
model._setattr_tracking = setattr_tracking
else:

View File

@ -111,16 +111,20 @@ class TestModelCloning(keras_parameterized.TestCase):
model = models.Sequential(_get_layers(input_shape, add_input_layer))
# Sanity check
self.assertEqual(
isinstance(model._layers[0], keras.layers.InputLayer),
add_input_layer)
isinstance(
list(model._flatten_layers(include_self=False, recursive=False))[0],
keras.layers.InputLayer), add_input_layer)
self.assertEqual(model._is_graph_network, add_input_layer)
# With placeholder creation -- clone model should have an InputLayer
# if the original model has one.
new_model = clone_fn(model)
self.assertEqual(
isinstance(new_model._layers[0], keras.layers.InputLayer),
add_input_layer)
isinstance(
list(
new_model._flatten_layers(include_self=False,
recursive=False))[0],
keras.layers.InputLayer), add_input_layer)
self.assertEqual(new_model._is_graph_network, model._is_graph_network)
if input_shape and not ops.executing_eagerly_outside_functions():
# update ops from batch norm needs to be included
@ -129,7 +133,9 @@ class TestModelCloning(keras_parameterized.TestCase):
# On top of new tensor -- clone model should always have an InputLayer.
input_a = keras.Input(shape=(4,))
new_model = clone_fn(model, input_tensors=input_a)
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
self.assertIsInstance(
list(new_model._flatten_layers(include_self=False, recursive=False))[0],
keras.layers.InputLayer)
self.assertTrue(new_model._is_graph_network)
# On top of new, non-Keras tensor -- clone model should always have an
@ -139,7 +145,10 @@ class TestModelCloning(keras_parameterized.TestCase):
# saying they should not be used with EagerTensors
input_a = keras.backend.variable(val_a)
new_model = clone_fn(model, input_tensors=input_a)
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
self.assertIsInstance(
list(new_model._flatten_layers(include_self=False,
recursive=False))[0],
keras.layers.InputLayer)
self.assertTrue(new_model._is_graph_network)
@keras_parameterized.run_all_keras_modes

View File

@ -328,7 +328,9 @@ class TestModelRevive(ReviveTestBase):
])
model.save(self.path, save_format='tf')
revived = keras_load.load(self.path)
self.assertEqual(dtypes.string, revived._layers[0].dtype)
revived_layers = list(
revived._flatten_layers(include_self=False, recursive=False))
self.assertEqual(dtypes.string, revived_layers[0].dtype)
@parameterized.named_parameters(
('default_config', CustomNetworkDefaultConfig),

View File

@ -25,7 +25,6 @@ from tensorflow.python.eager import context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import LazyLoader
@ -117,10 +116,11 @@ def layer_uses_training_bool(layer):
def list_all_layers(obj):
if isinstance(obj, training_lib.Model):
# Handle special case of Sequential, which doesn't return
# the `Input` layer.
return obj.layers
else:
return list(
layer_utils.filter_empty_layer_containers(obj._layers)) # pylint: disable=protected-access
return list(obj._flatten_layers(include_self=False, recursive=False)) # pylint: disable=protected-access
def list_all_layers_and_sublayers(obj):

View File

@ -175,6 +175,21 @@ class Module(tracking.AutoTrackable):
return tuple(
self._flatten(predicate=_is_trainable_variable, expand_composites=True))
@property
def non_trainable_variables(self):
"""Sequence of non-trainable variables owned by this module and its submodules.
Note: this method uses reflection to find variables on the current instance
and submodules. For performance reasons you may wish to cache the result
of calling this method if you don't expect the return value to change.
Returns:
A sequence of variables for the current module (sorted by attribute
name) followed by variables from all submodules recursively (breadth
first).
"""
return tuple(self._flatten(predicate=_is_non_trainable_variable))
@property
def submodules(self):
"""Sequence of all sub-modules.
@ -310,6 +325,10 @@ def _is_trainable_variable(obj):
return _is_variable(obj) and getattr(obj, "trainable", False)
def _is_non_trainable_variable(obj):
return _is_variable(obj) and not getattr(obj, "trainable", False)
def _is_module(obj):
return isinstance(obj, Module)
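
A minimal usage sketch of the new public property (variable names are
illustrative):

import tensorflow as tf

m = tf.Module()
m.a = tf.Variable(1., trainable=True)
m.b = tf.Variable(2., trainable=False)

assert len(m.trainable_variables) == 1      # (m.a,)
assert len(m.non_trainable_variables) == 1  # (m.b,)
assert m.non_trainable_variables[0] is m.b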

View File

@ -35,8 +35,12 @@ from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import layer_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.compat import collections_abc
module = lazy_loader.LazyLoader(
"module", globals(), "tensorflow.python.module.module")
class NoDependency(object):
"""Allows attribute assignment to `Trackable` objects with no dependency.
@ -213,17 +217,45 @@ class TrackableDataStructure(base.Trackable):
@property
def trainable_weights(self):
return layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._self_extra_variables)
if not self._self_trainable:
return []
trainable_variables = []
for obj in self._values:
if isinstance(obj, (TrackableDataStructure, module.Module)):
trainable_variables += obj.trainable_variables
trainable_extra_variables = [
v for v in self._self_extra_variables if v.trainable
]
return trainable_variables + trainable_extra_variables
@property
def non_trainable_weights(self):
return layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._self_extra_variables)
trainable_extra_variables = [
v for v in self._self_extra_variables if v.trainable
]
non_trainable_extra_variables = [
v for v in self._self_extra_variables if not v.trainable
]
non_trainable_variables = []
for obj in self._values:
if isinstance(obj, (TrackableDataStructure, module.Module)):
non_trainable_variables += obj.non_trainable_variables
if not self._self_trainable:
# Return order is all trainable vars, then all non-trainable vars.
trainable_variables = []
for obj in self._values:
if isinstance(obj, (TrackableDataStructure, module.Module)):
trainable_variables += obj.trainable_variables
non_trainable_variables = (
trainable_variables + trainable_extra_variables +
non_trainable_variables + non_trainable_extra_variables)
else:
non_trainable_variables = (
non_trainable_variables + non_trainable_extra_variables)
return non_trainable_variables
@property
def weights(self):

View File

@ -12,6 +12,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operator"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -51,6 +51,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operator"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -66,6 +66,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -55,6 +55,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -58,6 +58,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -49,6 +49,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -12,6 +12,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operator"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -51,6 +51,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operator"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "operators"
mtype: "<type \'property\'>"

View File

@ -66,6 +66,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -55,6 +55,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -58,6 +58,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"

View File

@ -49,6 +49,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"