From e0f8b4cea2486cae671c60cc7645913bea1abf60 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 28 Feb 2019 13:44:13 -0800 Subject: [PATCH] Switch RNN and Wrapper to use standard Layer sub-Layer tracking The "states" property needs an opt-out so it doesn't catch states which are Variables. And Layer.__setattr__ in general will now ignore @property.setters when tracking, on the assumption that the setter itself will assign to something (or as in this case disable tracking). There's also an odd issue with isinstance(obj, list) that only showed up in the Python 3.4 presubmit. It may be a 3.4-specific issue, since ListWrapper passes isinstance(ListWrapper(), list) in Python 3.6. For now I've used collections.Sequence which is slightly better anyway. PiperOrigin-RevId: 236189545 --- tensorflow/python/keras/engine/base_layer.py | 4 +- tensorflow/python/keras/layers/recurrent.py | 112 ++---------------- .../python/keras/layers/unified_lstm_test.py | 8 ++ tensorflow/python/keras/layers/wrappers.py | 84 +------------ .../python/keras/layers/wrappers_test.py | 3 +- ...nsorflow.keras.layers.-bidirectional.pbtxt | 4 - ...rflow.keras.layers.-time-distributed.pbtxt | 4 - .../v1/tensorflow.keras.layers.-wrapper.pbtxt | 4 - ...nsorflow.keras.layers.-bidirectional.pbtxt | 4 - ...rflow.keras.layers.-time-distributed.pbtxt | 4 - .../v2/tensorflow.keras.layers.-wrapper.pbtxt | 4 - 11 files changed, 24 insertions(+), 211 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 404d3023344..13586de6a89 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1776,7 +1776,9 @@ class Layer(trackable.Trackable): def __setattr__(self, name, value): if (not getattr(self, '_setattr_tracking', True) or - getattr(self, '_is_graph_network', False)): + getattr(self, '_is_graph_network', False) or + # Exclude @property.setters from tracking + hasattr(self.__class__, name)): super(Layer, self).__setattr__(name, value) return diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 11ac4494af0..23a22803b8e 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import uuid import numpy as np @@ -191,74 +192,6 @@ class StackedRNNCells(Layer): deserialize_layer(cell_config, custom_objects=custom_objects)) return cls(cells, **config) - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for cell in self.cells: - if isinstance(cell, Layer): - weights += cell.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for cell in self.cells: - if isinstance(cell, Layer): - weights += cell.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for cell in self.cells: - if isinstance(cell, Layer): - trainable_weights += cell.trainable_weights - return trainable_weights + weights - return weights - - def get_weights(self): - """Retrieves the weights of the model. - - Returns: - A flat list of Numpy arrays. - """ - weights = [] - for cell in self.cells: - if isinstance(cell, Layer): - weights += cell.weights - return K.batch_get_value(weights) - - def set_weights(self, weights): - """Sets the weights of the model. - - Arguments: - weights: A list of Numpy arrays with shapes and types matching - the output of `model.get_weights()`. - """ - tuples = [] - for cell in self.cells: - if isinstance(cell, Layer): - num_param = len(cell.weights) - weights = weights[:num_param] - for sw, w in zip(cell.weights, weights): - tuples.append((sw, w)) - weights = weights[num_param:] - K.batch_set_value(tuples) - - @property - def losses(self): - losses = [] - for cell in self.cells: - if isinstance(cell, Layer): - losses += cell.losses - return losses + self._losses - - @property - def updates(self): - updates = [] - for cell in self.cells: - if isinstance(cell, Layer): - updates += cell.updates - return updates + self._updates @keras_export('keras.layers.RNN') @@ -455,8 +388,6 @@ class RNN(Layer): ``` """ - _setattr_tracking = False - def __init__(self, cell, return_sequences=False, @@ -481,8 +412,6 @@ class RNN(Layer): self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) super(RNN, self).__init__(**kwargs) self.cell = cell - if isinstance(cell, trackable.Trackable): - self._track_trackable(self.cell, name='cell') self.return_sequences = return_sequences self.return_state = return_state self.go_backwards = go_backwards @@ -508,6 +437,9 @@ class RNN(Layer): return self._states @states.setter + # Automatic tracking catches "self._states" which adds an extra weight and + # breaks HDF5 checkpoints. + @trackable.no_automatic_dependency_tracking def states(self, states): self._states = states @@ -871,7 +803,8 @@ class RNN(Layer): # input shape: `(samples, time (padded with zeros), input_dim)` # note that the .build() method of subclasses MUST define # self.input_spec and self.state_spec with complete input shapes. - if isinstance(inputs, list): + if (isinstance(inputs, collections.Sequence) + and not isinstance(inputs, tuple)): # get initial_state from full input spec # as they could be copied to multiple GPU. if self._num_constants is None: @@ -988,36 +921,6 @@ class RNN(Layer): layer._num_constants = num_constants return layer - @property - def trainable_weights(self): - if not self.trainable: - return [] - if isinstance(self.cell, Layer): - return self.cell.trainable_weights - return [] - - @property - def non_trainable_weights(self): - if isinstance(self.cell, Layer): - if not self.trainable: - return self.cell.weights - return self.cell.non_trainable_weights - return [] - - @property - def losses(self): - layer_losses = super(RNN, self).losses - if isinstance(self.cell, Layer): - return self.cell.losses + layer_losses - return layer_losses - - @property - def updates(self): - updates = [] - if isinstance(self.cell, Layer): - updates += self.cell.updates - return updates + self._updates - @keras_export('keras.layers.AbstractRNNCell') class AbstractRNNCell(Layer): @@ -2262,8 +2165,6 @@ class UnifiedGRU(DropoutRNNCellMixin, GRU): call of the cell. """ - _setattr_tracking = False # TODO(allenl): Figure out why this is needed. - def __init__(self, units, activation='tanh', @@ -3666,3 +3567,4 @@ def _runtime(runtime_name): with ops.device('/cpu:0'): return constant_op.constant( runtime_name, dtype=dtypes.string, name='runtime') + diff --git a/tensorflow/python/keras/layers/unified_lstm_test.py b/tensorflow/python/keras/layers/unified_lstm_test.py index 316ce74d801..1e94b30ba39 100644 --- a/tensorflow/python/keras/layers/unified_lstm_test.py +++ b/tensorflow/python/keras/layers/unified_lstm_test.py @@ -43,6 +43,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import gradient_descent +from tensorflow.python.util import nest # Global config for grappler setting that is used for graph mode test. @@ -184,6 +185,7 @@ class UnifiedLSTMTest(keras_parameterized.TestCase): layer = keras.layers.UnifiedLSTM(units, stateful=True) layer.build((num_samples, timesteps, embedding_dim)) + initial_weight_count = len(layer.weights) layer.reset_states() assert len(layer.states) == num_states assert layer.states[0] is not None @@ -205,6 +207,12 @@ class UnifiedLSTMTest(keras_parameterized.TestCase): with self.assertRaises(ValueError): layer.reset_states([1] * (len(layer.states) + 1)) + self.assertEqual(initial_weight_count, len(layer.weights)) + # Variables in "states" shouldn't show up in .weights + layer.states = nest.map_structure(variables.Variable, values) + layer.reset_states() + self.assertEqual(initial_weight_count, len(layer.weights)) + def test_specify_state_with_masking(self): num_states = 2 timesteps = 3 diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 01b801330df..7fd375fbbe7 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -30,7 +30,6 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops -from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export @@ -47,7 +46,6 @@ class Wrapper(Layer): layer: The layer to be wrapped. """ - @trackable.no_automatic_dependency_tracking def __init__(self, layer, **kwargs): assert isinstance(layer, Layer) self.layer = layer @@ -67,36 +65,6 @@ class Wrapper(Layer): else: return None - @property - def trainable(self): - return self.layer.trainable - - @trainable.setter - def trainable(self, value): - self.layer.trainable = value - - @property - def trainable_weights(self): - return self.layer.trainable_weights - - @property - def non_trainable_weights(self): - return self.layer.non_trainable_weights - - @property - def updates(self): - return self.layer.updates + self._updates - - @property - def losses(self): - return self.layer.losses + self._losses - - def get_weights(self): - return self.layer.get_weights() - - def set_weights(self, weights): - self.layer.set_weights(weights) - def get_config(self): config = { 'layer': { @@ -180,7 +148,6 @@ class TimeDistributed(Wrapper): '`Layer` instance. You passed: {input}'.format(input=layer)) super(TimeDistributed, self).__init__(layer, **kwargs) self.supports_masking = True - self._track_trackable(layer, name='layer') # It is safe to use the fast, reshape-based approach with all of our # built-in Layers. @@ -407,7 +374,6 @@ class Bidirectional(Wrapper): ``` """ - @trackable.no_automatic_dependency_tracking def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): if not isinstance(layer, Layer): raise ValueError( @@ -438,28 +404,12 @@ class Bidirectional(Wrapper): self.supports_masking = True self._trainable = True self._num_constants = None + # We don't want to track `layer` since we're already tracking the two copies + # of it we actually run. + self._setattr_tracking = False super(Bidirectional, self).__init__(layer, **kwargs) + self._setattr_tracking = True self.input_spec = layer.input_spec - self._track_trackable(self.forward_layer, name='forward_layer') - self._track_trackable(self.backward_layer, name='backward_layer') - - @property - def trainable(self): - return self._trainable - - @trainable.setter - def trainable(self, value): - self._trainable = value - self.forward_layer.trainable = value - self.backward_layer.trainable = value - - def get_weights(self): - return self.forward_layer.get_weights() + self.backward_layer.get_weights() - - def set_weights(self, weights): - nw = len(weights) - self.forward_layer.set_weights(weights[:nw // 2]) - self.backward_layer.set_weights(weights[nw // 2:]) @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): @@ -653,32 +603,6 @@ class Bidirectional(Wrapper): return [output_mask] + state_mask * 2 return output_mask - @property - def trainable_weights(self): - if hasattr(self.forward_layer, 'trainable_weights'): - return (self.forward_layer.trainable_weights + - self.backward_layer.trainable_weights) - return [] - - @property - def non_trainable_weights(self): - if hasattr(self.forward_layer, 'non_trainable_weights'): - return (self.forward_layer.non_trainable_weights + - self.backward_layer.non_trainable_weights) - return [] - - @property - def updates(self): - if hasattr(self.forward_layer, 'updates'): - return self.forward_layer.updates + self.backward_layer.updates - return [] - - @property - def losses(self): - if hasattr(self.forward_layer, 'losses'): - return self.forward_layer.losses + self.backward_layer.losses - return [] - @property def constraints(self): constraints = {} diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 4659ca924c2..bb54adf2c76 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -90,7 +90,8 @@ class TimeDistributedTest(test.TestCase): # check whether the model variables are present in the # trackable list of objects - checkpointed_objects = set(trackable_util.list_objects(model)) + checkpointed_objects = object_identity.ObjectIdentitySet( + trackable_util.list_objects(model)) for v in model.variables: self.assertIn(v, checkpointed_objects) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt index 95eb6f69ecc..43af4aa1ec1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt @@ -69,10 +69,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt index 009ecca9a7f..709aac579db 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt @@ -65,10 +65,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt index 6604ac05d91..ee06ae5059d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt @@ -64,10 +64,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt index 95eb6f69ecc..43af4aa1ec1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt @@ -69,10 +69,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt index 009ecca9a7f..709aac579db 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt @@ -65,10 +65,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt index 6604ac05d91..ee06ae5059d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt @@ -64,10 +64,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: ""