diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 2fc669be2c8..917e871a7f0 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -905,7 +905,7 @@ class Layer(module.Module):
   def trainable_weights(self):
     if self.trainable:
       nested = self._gather_children_attribute('trainable_weights')
-      return self._trainable_weights + nested
+      return self._dedup_weights(self._trainable_weights + nested)
     else:
       return []
 
@@ -913,10 +913,12 @@
   def non_trainable_weights(self):
     if self.trainable:
       nested = self._gather_children_attribute('non_trainable_weights')
-      return self._non_trainable_weights + nested
+      non_trainable_weights = self._non_trainable_weights + nested
     else:
       nested = self._gather_children_attribute('weights')
-      return self._trainable_weights + self._non_trainable_weights + nested
+      non_trainable_weights = (
+          self._trainable_weights + self._non_trainable_weights + nested)
+    return self._dedup_weights(non_trainable_weights)
 
   @property
   def weights(self):
@@ -2452,14 +2454,13 @@ class Layer(module.Module):
             serialization_cache))
     return fns
 
-  @property
-  def _unique_trainable_weights(self):
-    """Dedupe trainable weights while maintaining order as much as possible."""
-    trainable_weights = self.trainable_weights
+  def _dedup_weights(self, weights):
+    """Dedupe weights while maintaining order as much as possible."""
     output, seen_weights = [], object_identity.ObjectIdentitySet()
-    for w in trainable_weights:
+    for w in weights:
       if w not in seen_weights:
         output.append(w)
+        # Track the Variable's identity to avoid __eq__ issues.
         seen_weights.add(w)
     return output
 
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index ff5a479a01a..561dc8a8ba7 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -472,6 +472,11 @@ class Network(base_layer.Layer):
     Returns:
       A list of variables.
     """
+    return self._dedup_weights(self._undeduplicated_weights)
+
+  @property
+  def _undeduplicated_weights(self):
+    """Returns the undeduplicated list of all layer variables/weights."""
     self._assert_weights_created()
     weights = []
     for layer in self._layers:
@@ -535,18 +540,21 @@
   @property
   def trainable_weights(self):
     self._assert_weights_created()
-    return trackable_layer_utils.gather_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._trainable_weights))
 
   @property
   def non_trainable_weights(self):
     self._assert_weights_created()
-    return trackable_layer_utils.gather_non_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._non_trainable_weights + self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_non_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._non_trainable_weights +
+            self._trainable_weights))
 
   @property
   def input_spec(self):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 5ca2a289755..f7e0710557a 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -386,7 +386,7 @@ class Model(network.Network):
     self.predict_function = None
 
     # Collected trainable weights, sorted in topological order.
-    self._collected_trainable_weights = self._unique_trainable_weights
+    self._collected_trainable_weights = self.trainable_weights
 
     # Validate all variables were correctly created in distribution scope.
     if self._distribution_strategy and not self._compile_distribution:
@@ -1535,7 +1535,7 @@
     # Set metric attributes on model.
     self._set_metric_attributes()
 
-    self._collected_trainable_weights = self._unique_trainable_weights
+    self._collected_trainable_weights = self.trainable_weights
 
   def _update_sample_weight_modes(self, sample_weights=None):
     """Updates sample weight modes based on training/eval inputs.
@@ -2046,8 +2046,7 @@
     if not hasattr(self, '_collected_trainable_weights'):
       return
 
-    if (len(self._unique_trainable_weights) !=
-        len(self._collected_trainable_weights)):
+    if len(self.trainable_weights) != len(self._collected_trainable_weights):
       logging.log_first_n(
           logging.WARN, 'Discrepancy between trainable weights and collected'
           ' trainable weights, did you set `model.trainable`'
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index ab16efc3646..be1b2e89d90 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -258,7 +258,7 @@ def _process_single_batch(model,
   else:
     scaled_total_loss = total_loss
   if training:
-    trainable_weights = model._unique_trainable_weights
+    trainable_weights = model.trainable_weights
     if trainable_weights:
      # TODO(tanzheny) b/132690565: Provide mechanism for user to override
      # model.train_on_batch.
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index d40de4ba4de..d8c09918d8f 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -904,6 +904,23 @@ class TrainingTest(keras_parameterized.TestCase):
     x2 = model.predict(val_a)
     self.assertAllClose(x1, x2, atol=1e-7)
 
+  def test_weight_deduplication_in_methods(self):
+    inp = keras.layers.Input(shape=(1,))
+    bn = keras.layers.BatchNormalization()
+    d = keras.layers.Dense(1)
+
+    m0 = keras.models.Model(inp, d(bn(inp)))
+    m1 = keras.models.Model(inp, d(bn(inp)))
+
+    x0 = m0(inp)
+    x1 = m1(inp)
+    x = keras.layers.Add()([x0, x1])
+
+    model = keras.models.Model(inp, x)
+    self.assertLen(model.trainable_weights, 4)
+    self.assertLen(model.non_trainable_weights, 2)
+    self.assertLen(model.weights, 6)
+
   @keras_parameterized.run_all_keras_modes
   def test_weight_deduplication(self):
     class WatchingLayer(keras.layers.Layer):
diff --git a/tensorflow/python/keras/premade/wide_deep.py b/tensorflow/python/keras/premade/wide_deep.py
index ff5dd5e2ed3..7dc03247982 100644
--- a/tensorflow/python/keras/premade/wide_deep.py
+++ b/tensorflow/python/keras/premade/wide_deep.py
@@ -102,8 +102,8 @@ class WideDeepModel(training.Model):
 
   # This does not support gradient scaling and LossScaleOptimizer.
   def _backwards(self, tape, loss):
-    linear_vars = self.linear_model._unique_trainable_weights  # pylint: disable=protected-access
-    dnn_vars = self.dnn_model._unique_trainable_weights  # pylint: disable=protected-access
+    linear_vars = self.linear_model.trainable_weights  # pylint: disable=protected-access
+    dnn_vars = self.dnn_model.trainable_weights  # pylint: disable=protected-access
     linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))
     linear_optimizer, dnn_optimizer = self._get_optimizers()
     linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
@@ -134,11 +134,11 @@
       # Training updates
       updates = []
       linear_updates = linear_optimizer.get_updates(
-          params=self.linear_model._unique_trainable_weights,  # pylint: disable=protected-access
+          params=self.linear_model.trainable_weights,  # pylint: disable=protected-access
           loss=self.total_loss)
       updates += linear_updates
       dnn_updates = dnn_optimizer.get_updates(
-          params=self.dnn_model._unique_trainable_weights,  # pylint: disable=protected-access
+          params=self.dnn_model.trainable_weights,  # pylint: disable=protected-access
           loss=self.total_loss)
       updates += dnn_updates
       # Unconditional updates
diff --git a/tensorflow/python/keras/saving/hdf5_format.py b/tensorflow/python/keras/saving/hdf5_format.py
index d23db5b7763..8027cd36330 100644
--- a/tensorflow/python/keras/saving/hdf5_format.py
+++ b/tensorflow/python/keras/saving/hdf5_format.py
@@ -75,6 +75,12 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
   # TODO(psv) Add warning when we save models that contain non-serializable
   # entities like metrics added using `add_metric` and losses added using
   # `add_loss.`
+  if len(model.weights) != len(model._undeduplicated_weights):
+    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
+                    'This is usually caused by `Variable`s being shared by '
+                    'Layers in the Model. These `Variable`s will be treated '
+                    'as separate `Variable`s when the Model is restored. To '
+                    'avoid this, please save with `save_format="tf"`.')
 
   if not isinstance(filepath, h5py.File):
     # If file exists and should not be overwritten.
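Note: the HDF5 warning above fires exactly when the deduplicated `model.weights` is shorter than the new `_undeduplicated_weights` property, i.e. when some `Variable` is reachable through more than one layer. The workaround the message points to is the SavedModel format, which keeps shared `Variable`s shared on restore; a hedged one-liner (the path is a placeholder):

model.save('/tmp/shared_model', save_format='tf')  # SavedModel preserves Variable sharing
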
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 6fff75d080b..c0de7308e67 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -235,7 +235,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
   if hasattr(model, '_collected_trainable_weights'):
     trainable_count = count_params(model._collected_trainable_weights)
   else:
-    trainable_count = count_params(model._unique_trainable_weights)
+    trainable_count = count_params(model.trainable_weights)
 
   non_trainable_count = count_params(model.non_trainable_weights)
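Note: end to end, the deduplicated properties behave as asserted by the new test_weight_deduplication_in_methods above; the same construction restated as a runnable snippet, with the counts taken directly from the test's assertions:

from tensorflow import keras

inp = keras.layers.Input(shape=(1,))
bn = keras.layers.BatchNormalization()
d = keras.layers.Dense(1)

# Two functional Models share the same BatchNormalization and Dense layers.
m0 = keras.models.Model(inp, d(bn(inp)))
m1 = keras.models.Model(inp, d(bn(inp)))

combined = keras.models.Model(inp, keras.layers.Add()([m0(inp), m1(inp)]))

assert len(combined.trainable_weights) == 4      # Dense kernel/bias + BN gamma/beta
assert len(combined.non_trainable_weights) == 2  # BN moving mean/variance
assert len(combined.weights) == 6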