Switch RNN and Wrapper to use standard Layer sub-Layer tracking
The "states" property needs an opt-out so it doesn't catch states which are Variables. And Layer.__setattr__ in general will now ignore @property.setters when tracking, on the assumption that the setter itself will assign to something (or as in this case disable tracking). There's also an odd issue with isinstance(obj, list) that only showed up in the Python 3.4 presubmit. It may be a 3.4-specific issue, since ListWrapper passes isinstance(ListWrapper(), list) in Python 3.6. For now I've used collections.Sequence which is slightly better anyway. PiperOrigin-RevId: 236189545
This commit is contained in:
parent
f929a5dbd9
commit
e0f8b4cea2
@ -1776,7 +1776,9 @@ class Layer(trackable.Trackable):
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if (not getattr(self, '_setattr_tracking', True) or
|
||||
getattr(self, '_is_graph_network', False)):
|
||||
getattr(self, '_is_graph_network', False) or
|
||||
# Exclude @property.setters from tracking
|
||||
hasattr(self.__class__, name)):
|
||||
super(Layer, self).__setattr__(name, value)
|
||||
return
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
@ -191,74 +192,6 @@ class StackedRNNCells(Layer):
|
||||
deserialize_layer(cell_config, custom_objects=custom_objects))
|
||||
return cls(cells, **config)
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
if not self.trainable:
|
||||
return []
|
||||
weights = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
weights += cell.trainable_weights
|
||||
return weights
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
weights = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
weights += cell.non_trainable_weights
|
||||
if not self.trainable:
|
||||
trainable_weights = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
trainable_weights += cell.trainable_weights
|
||||
return trainable_weights + weights
|
||||
return weights
|
||||
|
||||
def get_weights(self):
|
||||
"""Retrieves the weights of the model.
|
||||
|
||||
Returns:
|
||||
A flat list of Numpy arrays.
|
||||
"""
|
||||
weights = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
weights += cell.weights
|
||||
return K.batch_get_value(weights)
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the weights of the model.
|
||||
|
||||
Arguments:
|
||||
weights: A list of Numpy arrays with shapes and types matching
|
||||
the output of `model.get_weights()`.
|
||||
"""
|
||||
tuples = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
num_param = len(cell.weights)
|
||||
weights = weights[:num_param]
|
||||
for sw, w in zip(cell.weights, weights):
|
||||
tuples.append((sw, w))
|
||||
weights = weights[num_param:]
|
||||
K.batch_set_value(tuples)
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
losses = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
losses += cell.losses
|
||||
return losses + self._losses
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
updates = []
|
||||
for cell in self.cells:
|
||||
if isinstance(cell, Layer):
|
||||
updates += cell.updates
|
||||
return updates + self._updates
|
||||
|
||||
|
||||
@keras_export('keras.layers.RNN')
|
||||
@ -455,8 +388,6 @@ class RNN(Layer):
|
||||
```
|
||||
"""
|
||||
|
||||
_setattr_tracking = False
|
||||
|
||||
def __init__(self,
|
||||
cell,
|
||||
return_sequences=False,
|
||||
@ -481,8 +412,6 @@ class RNN(Layer):
|
||||
self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
|
||||
super(RNN, self).__init__(**kwargs)
|
||||
self.cell = cell
|
||||
if isinstance(cell, trackable.Trackable):
|
||||
self._track_trackable(self.cell, name='cell')
|
||||
self.return_sequences = return_sequences
|
||||
self.return_state = return_state
|
||||
self.go_backwards = go_backwards
|
||||
@ -508,6 +437,9 @@ class RNN(Layer):
|
||||
return self._states
|
||||
|
||||
@states.setter
|
||||
# Automatic tracking catches "self._states" which adds an extra weight and
|
||||
# breaks HDF5 checkpoints.
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def states(self, states):
|
||||
self._states = states
|
||||
|
||||
@ -871,7 +803,8 @@ class RNN(Layer):
|
||||
# input shape: `(samples, time (padded with zeros), input_dim)`
|
||||
# note that the .build() method of subclasses MUST define
|
||||
# self.input_spec and self.state_spec with complete input shapes.
|
||||
if isinstance(inputs, list):
|
||||
if (isinstance(inputs, collections.Sequence)
|
||||
and not isinstance(inputs, tuple)):
|
||||
# get initial_state from full input spec
|
||||
# as they could be copied to multiple GPU.
|
||||
if self._num_constants is None:
|
||||
@ -988,36 +921,6 @@ class RNN(Layer):
|
||||
layer._num_constants = num_constants
|
||||
return layer
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
if not self.trainable:
|
||||
return []
|
||||
if isinstance(self.cell, Layer):
|
||||
return self.cell.trainable_weights
|
||||
return []
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
if isinstance(self.cell, Layer):
|
||||
if not self.trainable:
|
||||
return self.cell.weights
|
||||
return self.cell.non_trainable_weights
|
||||
return []
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
layer_losses = super(RNN, self).losses
|
||||
if isinstance(self.cell, Layer):
|
||||
return self.cell.losses + layer_losses
|
||||
return layer_losses
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
updates = []
|
||||
if isinstance(self.cell, Layer):
|
||||
updates += self.cell.updates
|
||||
return updates + self._updates
|
||||
|
||||
|
||||
@keras_export('keras.layers.AbstractRNNCell')
|
||||
class AbstractRNNCell(Layer):
|
||||
@ -2262,8 +2165,6 @@ class UnifiedGRU(DropoutRNNCellMixin, GRU):
|
||||
call of the cell.
|
||||
"""
|
||||
|
||||
_setattr_tracking = False # TODO(allenl): Figure out why this is needed.
|
||||
|
||||
def __init__(self,
|
||||
units,
|
||||
activation='tanh',
|
||||
@ -3666,3 +3567,4 @@ def _runtime(runtime_name):
|
||||
with ops.device('/cpu:0'):
|
||||
return constant_op.constant(
|
||||
runtime_name, dtype=dtypes.string, name='runtime')
|
||||
|
||||
|
@ -43,6 +43,7 @@ from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# Global config for grappler setting that is used for graph mode test.
|
||||
@ -184,6 +185,7 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
|
||||
|
||||
layer = keras.layers.UnifiedLSTM(units, stateful=True)
|
||||
layer.build((num_samples, timesteps, embedding_dim))
|
||||
initial_weight_count = len(layer.weights)
|
||||
layer.reset_states()
|
||||
assert len(layer.states) == num_states
|
||||
assert layer.states[0] is not None
|
||||
@ -205,6 +207,12 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
layer.reset_states([1] * (len(layer.states) + 1))
|
||||
|
||||
self.assertEqual(initial_weight_count, len(layer.weights))
|
||||
# Variables in "states" shouldn't show up in .weights
|
||||
layer.states = nest.map_structure(variables.Variable, values)
|
||||
layer.reset_states()
|
||||
self.assertEqual(initial_weight_count, len(layer.weights))
|
||||
|
||||
def test_specify_state_with_masking(self):
|
||||
num_states = 2
|
||||
timesteps = 3
|
||||
|
@ -30,7 +30,6 @@ from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -47,7 +46,6 @@ class Wrapper(Layer):
|
||||
layer: The layer to be wrapped.
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, layer, **kwargs):
|
||||
assert isinstance(layer, Layer)
|
||||
self.layer = layer
|
||||
@ -67,36 +65,6 @@ class Wrapper(Layer):
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def trainable(self):
|
||||
return self.layer.trainable
|
||||
|
||||
@trainable.setter
|
||||
def trainable(self, value):
|
||||
self.layer.trainable = value
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
return self.layer.trainable_weights
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
return self.layer.non_trainable_weights
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
return self.layer.updates + self._updates
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
return self.layer.losses + self._losses
|
||||
|
||||
def get_weights(self):
|
||||
return self.layer.get_weights()
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.layer.set_weights(weights)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'layer': {
|
||||
@ -180,7 +148,6 @@ class TimeDistributed(Wrapper):
|
||||
'`Layer` instance. You passed: {input}'.format(input=layer))
|
||||
super(TimeDistributed, self).__init__(layer, **kwargs)
|
||||
self.supports_masking = True
|
||||
self._track_trackable(layer, name='layer')
|
||||
|
||||
# It is safe to use the fast, reshape-based approach with all of our
|
||||
# built-in Layers.
|
||||
@ -407,7 +374,6 @@ class Bidirectional(Wrapper):
|
||||
```
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
|
||||
if not isinstance(layer, Layer):
|
||||
raise ValueError(
|
||||
@ -438,28 +404,12 @@ class Bidirectional(Wrapper):
|
||||
self.supports_masking = True
|
||||
self._trainable = True
|
||||
self._num_constants = None
|
||||
# We don't want to track `layer` since we're already tracking the two copies
|
||||
# of it we actually run.
|
||||
self._setattr_tracking = False
|
||||
super(Bidirectional, self).__init__(layer, **kwargs)
|
||||
self._setattr_tracking = True
|
||||
self.input_spec = layer.input_spec
|
||||
self._track_trackable(self.forward_layer, name='forward_layer')
|
||||
self._track_trackable(self.backward_layer, name='backward_layer')
|
||||
|
||||
@property
|
||||
def trainable(self):
|
||||
return self._trainable
|
||||
|
||||
@trainable.setter
|
||||
def trainable(self, value):
|
||||
self._trainable = value
|
||||
self.forward_layer.trainable = value
|
||||
self.backward_layer.trainable = value
|
||||
|
||||
def get_weights(self):
|
||||
return self.forward_layer.get_weights() + self.backward_layer.get_weights()
|
||||
|
||||
def set_weights(self, weights):
|
||||
nw = len(weights)
|
||||
self.forward_layer.set_weights(weights[:nw // 2])
|
||||
self.backward_layer.set_weights(weights[nw // 2:])
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def compute_output_shape(self, input_shape):
|
||||
@ -653,32 +603,6 @@ class Bidirectional(Wrapper):
|
||||
return [output_mask] + state_mask * 2
|
||||
return output_mask
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
if hasattr(self.forward_layer, 'trainable_weights'):
|
||||
return (self.forward_layer.trainable_weights +
|
||||
self.backward_layer.trainable_weights)
|
||||
return []
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
if hasattr(self.forward_layer, 'non_trainable_weights'):
|
||||
return (self.forward_layer.non_trainable_weights +
|
||||
self.backward_layer.non_trainable_weights)
|
||||
return []
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
if hasattr(self.forward_layer, 'updates'):
|
||||
return self.forward_layer.updates + self.backward_layer.updates
|
||||
return []
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
if hasattr(self.forward_layer, 'losses'):
|
||||
return self.forward_layer.losses + self.backward_layer.losses
|
||||
return []
|
||||
|
||||
@property
|
||||
def constraints(self):
|
||||
constraints = {}
|
||||
|
@ -90,7 +90,8 @@ class TimeDistributedTest(test.TestCase):
|
||||
|
||||
# check whether the model variables are present in the
|
||||
# trackable list of objects
|
||||
checkpointed_objects = set(trackable_util.list_objects(model))
|
||||
checkpointed_objects = object_identity.ObjectIdentitySet(
|
||||
trackable_util.list_objects(model))
|
||||
for v in model.variables:
|
||||
self.assertIn(v, checkpointed_objects)
|
||||
|
||||
|
@ -69,10 +69,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -65,10 +65,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -64,10 +64,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -69,10 +69,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -65,10 +65,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -64,10 +64,6 @@ tf_class {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user