Switch RNN and Wrapper to use standard Layer sub-Layer tracking

The "states" property needs an opt-out so it doesn't catch states which are Variables. And Layer.__setattr__ in general will now ignore @property.setters when tracking, on the assumption that the setter itself will assign to something (or as in this case disable tracking).

There's also an odd issue with isinstance(obj, list) that only showed up in the Python 3.4 presubmit. It may be 3.4-specific, since isinstance(ListWrapper(), list) passes in Python 3.6. For now I've switched to collections.Sequence, which is slightly better anyway.
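A rough sketch of the new check (ListWrapperStandIn is a hypothetical stand-in for the real ListWrapper, which subclasses list): collections.Sequence accepts both plain lists and list-like wrapper types, while the explicit tuple test keeps tuples on their existing code path.

import collections


class ListWrapperStandIn(list):  # stand-in for the real ListWrapper
  pass


def is_list_like(inputs):
  # collections.Sequence was the alias available in 2019; it moved to
  # collections.abc in Python 3.3 and the old alias was removed in 3.10.
  return (isinstance(inputs, collections.Sequence) and
          not isinstance(inputs, tuple))


assert is_list_like([1, 2])
assert is_list_like(ListWrapperStandIn([1, 2]))
assert not is_list_like((1, 2))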

PiperOrigin-RevId: 236189545
Allen Lavoie 2019-02-28 13:44:13 -08:00 committed by TensorFlower Gardener
parent f929a5dbd9
commit e0f8b4cea2
11 changed files with 24 additions and 211 deletions

View File

@@ -1776,7 +1776,9 @@ class Layer(trackable.Trackable):
   def __setattr__(self, name, value):
     if (not getattr(self, '_setattr_tracking', True) or
-        getattr(self, '_is_graph_network', False)):
+        getattr(self, '_is_graph_network', False) or
+        # Exclude @property.setters from tracking
+        hasattr(self.__class__, name)):
       super(Layer, self).__setattr__(name, value)
       return

View File

@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import uuid

 import numpy as np
@@ -191,74 +192,6 @@ class StackedRNNCells(Layer):
           deserialize_layer(cell_config, custom_objects=custom_objects))
     return cls(cells, **config)

-  @property
-  def trainable_weights(self):
-    if not self.trainable:
-      return []
-    weights = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        weights += cell.trainable_weights
-    return weights
-
-  @property
-  def non_trainable_weights(self):
-    weights = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        weights += cell.non_trainable_weights
-    if not self.trainable:
-      trainable_weights = []
-      for cell in self.cells:
-        if isinstance(cell, Layer):
-          trainable_weights += cell.trainable_weights
-      return trainable_weights + weights
-    return weights
-
-  def get_weights(self):
-    """Retrieves the weights of the model.
-
-    Returns:
-        A flat list of Numpy arrays.
-    """
-    weights = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        weights += cell.weights
-    return K.batch_get_value(weights)
-
-  def set_weights(self, weights):
-    """Sets the weights of the model.
-
-    Arguments:
-        weights: A list of Numpy arrays with shapes and types matching
-            the output of `model.get_weights()`.
-    """
-    tuples = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        num_param = len(cell.weights)
-        weights = weights[:num_param]
-        for sw, w in zip(cell.weights, weights):
-          tuples.append((sw, w))
-        weights = weights[num_param:]
-    K.batch_set_value(tuples)
-
-  @property
-  def losses(self):
-    losses = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        losses += cell.losses
-    return losses + self._losses
-
-  @property
-  def updates(self):
-    updates = []
-    for cell in self.cells:
-      if isinstance(cell, Layer):
-        updates += cell.updates
-    return updates + self._updates


 @keras_export('keras.layers.RNN')
@@ -455,8 +388,6 @@ class RNN(Layer):
   ```
   """

-  _setattr_tracking = False
-
   def __init__(self,
                cell,
                return_sequences=False,
@@ -481,8 +412,6 @@ class RNN(Layer):
     self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
     super(RNN, self).__init__(**kwargs)
     self.cell = cell
-    if isinstance(cell, trackable.Trackable):
-      self._track_trackable(self.cell, name='cell')
     self.return_sequences = return_sequences
     self.return_state = return_state
     self.go_backwards = go_backwards
@@ -508,6 +437,9 @@ class RNN(Layer):
     return self._states

   @states.setter
+  # Automatic tracking catches "self._states" which adds an extra weight and
+  # breaks HDF5 checkpoints.
+  @trackable.no_automatic_dependency_tracking
   def states(self, states):
     self._states = states
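For reference, a simplified sketch of what a decorator like trackable.no_automatic_dependency_tracking can do (an assumption about the mechanism for illustration, not the real tracking-module source): toggle the same _setattr_tracking flag that Layer.__setattr__ consults, for the duration of the wrapped call.

import functools


def no_automatic_dependency_tracking_sketch(method):
  @functools.wraps(method)
  def wrapper(self, *args, **kwargs):
    previous = getattr(self, '_setattr_tracking', True)
    self._setattr_tracking = False  # assignments inside won't be tracked
    try:
      return method(self, *args, **kwargs)
    finally:
      self._setattr_tracking = previous
  return wrapper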
@@ -871,7 +803,8 @@ class RNN(Layer):
     # input shape: `(samples, time (padded with zeros), input_dim)`
     # note that the .build() method of subclasses MUST define
     # self.input_spec and self.state_spec with complete input shapes.
-    if isinstance(inputs, list):
+    if (isinstance(inputs, collections.Sequence)
+        and not isinstance(inputs, tuple)):
       # get initial_state from full input spec
       # as they could be copied to multiple GPU.
       if self._num_constants is None:
@@ -988,36 +921,6 @@ class RNN(Layer):
     layer._num_constants = num_constants
     return layer

-  @property
-  def trainable_weights(self):
-    if not self.trainable:
-      return []
-    if isinstance(self.cell, Layer):
-      return self.cell.trainable_weights
-    return []
-
-  @property
-  def non_trainable_weights(self):
-    if isinstance(self.cell, Layer):
-      if not self.trainable:
-        return self.cell.weights
-      return self.cell.non_trainable_weights
-    return []
-
-  @property
-  def losses(self):
-    layer_losses = super(RNN, self).losses
-    if isinstance(self.cell, Layer):
-      return self.cell.losses + layer_losses
-    return layer_losses
-
-  @property
-  def updates(self):
-    updates = []
-    if isinstance(self.cell, Layer):
-      updates += self.cell.updates
-    return updates + self._updates


 @keras_export('keras.layers.AbstractRNNCell')
 class AbstractRNNCell(Layer):
@@ -2262,8 +2165,6 @@ class UnifiedGRU(DropoutRNNCellMixin, GRU):
       call of the cell.
   """

-  _setattr_tracking = False  # TODO(allenl): Figure out why this is needed.
-
   def __init__(self,
                units,
                activation='tanh',
@@ -3666,3 +3567,4 @@ def _runtime(runtime_name):
   with ops.device('/cpu:0'):
     return constant_op.constant(
         runtime_name, dtype=dtypes.string, name='runtime')
+

View File

@@ -43,6 +43,7 @@ from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import gradient_descent
+from tensorflow.python.util import nest


 # Global config for grappler setting that is used for graph mode test.
@@ -184,6 +185,7 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
     layer = keras.layers.UnifiedLSTM(units, stateful=True)
     layer.build((num_samples, timesteps, embedding_dim))
+    initial_weight_count = len(layer.weights)
     layer.reset_states()
     assert len(layer.states) == num_states
     assert layer.states[0] is not None
@@ -205,6 +207,12 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
     with self.assertRaises(ValueError):
       layer.reset_states([1] * (len(layer.states) + 1))
+    self.assertEqual(initial_weight_count, len(layer.weights))
+
+    # Variables in "states" shouldn't show up in .weights
+    layer.states = nest.map_structure(variables.Variable, values)
+    layer.reset_states()
+    self.assertEqual(initial_weight_count, len(layer.weights))

   def test_specify_state_with_masking(self):
     num_states = 2
     timesteps = 3
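The new assertions rely on nest.map_structure, which maps a function over every leaf of a nested structure and returns a structure of the same shape; here it wraps each of the test's state arrays in a Variable. A toy illustration with made-up shapes:

import numpy as np
from tensorflow.python.ops import variables
from tensorflow.python.util import nest

values = (np.zeros((2, 4)), np.ones((2, 4)))  # hypothetical state arrays
state_vars = nest.map_structure(variables.Variable, values)
assert len(state_vars) == 2  # same tuple structure, now holding Variables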

View File

@@ -30,7 +30,6 @@ from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
-from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import keras_export
@@ -47,7 +46,6 @@ class Wrapper(Layer):
       layer: The layer to be wrapped.
   """

-  @trackable.no_automatic_dependency_tracking
   def __init__(self, layer, **kwargs):
     assert isinstance(layer, Layer)
     self.layer = layer
@@ -67,36 +65,6 @@ class Wrapper(Layer):
     else:
       return None

-  @property
-  def trainable(self):
-    return self.layer.trainable
-
-  @trainable.setter
-  def trainable(self, value):
-    self.layer.trainable = value
-
-  @property
-  def trainable_weights(self):
-    return self.layer.trainable_weights
-
-  @property
-  def non_trainable_weights(self):
-    return self.layer.non_trainable_weights
-
-  @property
-  def updates(self):
-    return self.layer.updates + self._updates
-
-  @property
-  def losses(self):
-    return self.layer.losses + self._losses
-
-  def get_weights(self):
-    return self.layer.get_weights()
-
-  def set_weights(self, weights):
-    self.layer.set_weights(weights)
-
   def get_config(self):
     config = {
         'layer': {
@@ -180,7 +148,6 @@ class TimeDistributed(Wrapper):
           '`Layer` instance. You passed: {input}'.format(input=layer))
     super(TimeDistributed, self).__init__(layer, **kwargs)
     self.supports_masking = True
-    self._track_trackable(layer, name='layer')

     # It is safe to use the fast, reshape-based approach with all of our
     # built-in Layers.
@@ -407,7 +374,6 @@ class Bidirectional(Wrapper):
   ```
   """

-  @trackable.no_automatic_dependency_tracking
   def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
     if not isinstance(layer, Layer):
       raise ValueError(
@@ -438,28 +404,12 @@ class Bidirectional(Wrapper):
     self.supports_masking = True
     self._trainable = True
     self._num_constants = None
+    # We don't want to track `layer` since we're already tracking the two copies
+    # of it we actually run.
+    self._setattr_tracking = False
     super(Bidirectional, self).__init__(layer, **kwargs)
+    self._setattr_tracking = True
     self.input_spec = layer.input_spec
-    self._track_trackable(self.forward_layer, name='forward_layer')
-    self._track_trackable(self.backward_layer, name='backward_layer')
-
-  @property
-  def trainable(self):
-    return self._trainable
-
-  @trainable.setter
-  def trainable(self, value):
-    self._trainable = value
-    self.forward_layer.trainable = value
-    self.backward_layer.trainable = value
-
-  def get_weights(self):
-    return self.forward_layer.get_weights() + self.backward_layer.get_weights()
-
-  def set_weights(self, weights):
-    nw = len(weights)
-    self.forward_layer.set_weights(weights[:nw // 2])
-    self.backward_layer.set_weights(weights[nw // 2:])

   @tf_utils.shape_type_conversion
   def compute_output_shape(self, input_shape):
@@ -653,32 +603,6 @@ class Bidirectional(Wrapper):
       return [output_mask] + state_mask * 2
     return output_mask

-  @property
-  def trainable_weights(self):
-    if hasattr(self.forward_layer, 'trainable_weights'):
-      return (self.forward_layer.trainable_weights +
-              self.backward_layer.trainable_weights)
-    return []
-
-  @property
-  def non_trainable_weights(self):
-    if hasattr(self.forward_layer, 'non_trainable_weights'):
-      return (self.forward_layer.non_trainable_weights +
-              self.backward_layer.non_trainable_weights)
-    return []
-
-  @property
-  def updates(self):
-    if hasattr(self.forward_layer, 'updates'):
-      return self.forward_layer.updates + self.backward_layer.updates
-    return []
-
-  @property
-  def losses(self):
-    if hasattr(self.forward_layer, 'losses'):
-      return self.forward_layer.losses + self.backward_layer.losses
-    return []
-
   @property
   def constraints(self):
     constraints = {}

View File

@@ -90,7 +90,8 @@ class TimeDistributedTest(test.TestCase):
     # check whether the model variables are present in the
     # trackable list of objects
-    checkpointed_objects = set(trackable_util.list_objects(model))
+    checkpointed_objects = object_identity.ObjectIdentitySet(
+        trackable_util.list_objects(model))
     for v in model.variables:
       self.assertIn(v, checkpointed_objects)
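A plain set compares members via __eq__ and __hash__, which objects like Variables may override; an identity-based set makes the membership check mean "this exact object". A minimal sketch of the idea (the real ObjectIdentitySet lives in TensorFlow's tracking utilities and is more complete):

class IdentitySetSketch(object):
  """Set-like container that compares members by object identity."""

  def __init__(self, items=()):
    # Keying on id() is safe here because the dict keeps each item alive.
    self._by_id = {id(item): item for item in items}

  def __contains__(self, item):
    return id(item) in self._by_id

  def __len__(self):
    return len(self._by_id)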

View File

@@ -69,10 +69,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"

View File

@@ -65,10 +65,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"

View File

@@ -64,10 +64,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"

View File

@@ -69,10 +69,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"

View File

@@ -65,10 +65,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
  }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"

View File

@@ -64,10 +64,6 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "trainable"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"