Merge pull request #32257 from omalleyt12/cherrypicks_20IEW
[r2.0-CherryPick] Deduplicate Keras weights
Commit: 5a580681ad
tensorflow/python/keras/engine/base_layer.py

@@ -905,7 +905,7 @@ class Layer(module.Module):
   def trainable_weights(self):
     if self.trainable:
       nested = self._gather_children_attribute('trainable_weights')
-      return self._trainable_weights + nested
+      return self._dedup_weights(self._trainable_weights + nested)
     else:
       return []

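Note: the duplication this commit addresses arises when the same layer, and hence the same `Variable`s, is reachable through more than one path, e.g. a layer applied twice in a functional graph. A minimal sketch of the situation (the expected count assumes a standard `Dense` layer owning a kernel and a bias):

    import tensorflow as tf

    shared = tf.keras.layers.Dense(3)
    inp = tf.keras.layers.Input(shape=(3,))
    # The same Dense layer is applied twice, so its kernel and bias are
    # reachable through two paths; after this commit they are reported once.
    out = shared(shared(inp))
    model = tf.keras.Model(inp, out)
    assert len(model.trainable_weights) == 2  # kernel + bias, not 4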
@@ -913,10 +913,12 @@ class Layer(module.Module):
   def non_trainable_weights(self):
     if self.trainable:
       nested = self._gather_children_attribute('non_trainable_weights')
-      return self._non_trainable_weights + nested
+      non_trainable_weights = self._non_trainable_weights + nested
     else:
       nested = self._gather_children_attribute('weights')
-      return self._trainable_weights + self._non_trainable_weights + nested
+      non_trainable_weights = (
+          self._trainable_weights + self._non_trainable_weights + nested)
+    return self._dedup_weights(non_trainable_weights)

   @property
   def weights(self):
@@ -2452,14 +2454,13 @@ class Layer(module.Module):
           serialization_cache))
     return fns

-  @property
-  def _unique_trainable_weights(self):
-    """Dedupe trainable weights while maintaining order as much as possible."""
-    trainable_weights = self.trainable_weights
+  def _dedup_weights(self, weights):
+    """Dedupe weights while maintaining order as much as possible."""
     output, seen_weights = [], object_identity.ObjectIdentitySet()
-    for w in trainable_weights:
+    for w in weights:
       if w not in seen_weights:
         output.append(w)
+        # Track the Variable's identity to avoid __eq__ issues.
         seen_weights.add(w)
     return output
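Note on the implementation: `_dedup_weights` keys on object identity via `object_identity.ObjectIdentitySet` rather than on equality, because `Variable.__eq__` can be overloaded to build an elementwise comparison op instead of returning a bool, which breaks membership tests. A minimal standalone sketch of the same pattern, using `id()` in place of the internal helper:

    def dedup_by_identity(items):
      """Order-preserving dedup keyed on object identity, not __eq__."""
      output, seen_ids = [], set()
      for item in items:
        if id(item) not in seen_ids:
          output.append(item)
          seen_ids.add(id(item))
      return output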
tensorflow/python/keras/engine/network.py

@@ -472,6 +472,11 @@ class Network(base_layer.Layer):
     Returns:
       A list of variables.
     """
+    return self._dedup_weights(self._undeduplicated_weights)
+
+  @property
+  def _undeduplicated_weights(self):
+    """Returns the undeduplicated list of all layer variables/weights."""
     self._assert_weights_created()
     weights = []
     for layer in self._layers:
@@ -535,18 +540,21 @@ class Network(base_layer.Layer):
   @property
   def trainable_weights(self):
     self._assert_weights_created()
-    return trackable_layer_utils.gather_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._trainable_weights))

   @property
   def non_trainable_weights(self):
     self._assert_weights_created()
-    return trackable_layer_utils.gather_non_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._non_trainable_weights + self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_non_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._non_trainable_weights +
+            self._trainable_weights))

   @property
   def input_spec(self):
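A side note on why `non_trainable_weights` passes both `_non_trainable_weights` and `_trainable_weights` as `extra_variables`: when a network is frozen, its directly-owned trainable variables must be reported as non-trainable. A small sketch of the user-visible behavior (assuming a standard `Dense` layer with kernel and bias):

    import tensorflow as tf

    inp = tf.keras.layers.Input(shape=(4,))
    net = tf.keras.Model(inp, tf.keras.layers.Dense(2)(inp))

    net.trainable = False
    assert net.trainable_weights == []
    # The kernel and bias now surface as non-trainable instead:
    assert len(net.non_trainable_weights) == 2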
tensorflow/python/keras/engine/training.py

@@ -386,7 +386,7 @@ class Model(network.Network):
     self.predict_function = None

     # Collected trainable weights, sorted in topological order.
-    self._collected_trainable_weights = self._unique_trainable_weights
+    self._collected_trainable_weights = self.trainable_weights

     # Validate all variables were correctly created in distribution scope.
     if self._distribution_strategy and not self._compile_distribution:
@@ -1535,7 +1535,7 @@ class Model(network.Network):
     # Set metric attributes on model.
     self._set_metric_attributes()

-    self._collected_trainable_weights = self._unique_trainable_weights
+    self._collected_trainable_weights = self.trainable_weights

   def _update_sample_weight_modes(self, sample_weights=None):
     """Updates sample weight modes based on training/eval inputs.
@@ -2046,8 +2046,7 @@ class Model(network.Network):
     if not hasattr(self, '_collected_trainable_weights'):
       return

-    if (len(self._unique_trainable_weights) !=
-        len(self._collected_trainable_weights)):
+    if len(self.trainable_weights) != len(self._collected_trainable_weights):
       logging.log_first_n(
           logging.WARN, 'Discrepancy between trainable weights and collected'
           ' trainable weights, did you set `model.trainable`'
tensorflow/python/keras/engine/training_eager.py

@@ -258,7 +258,7 @@ def _process_single_batch(model,
   else:
     scaled_total_loss = total_loss
   if training:
-    trainable_weights = model._unique_trainable_weights
+    trainable_weights = model.trainable_weights
     if trainable_weights:
       # TODO(tanzheny) b/132690565: Provide mechanism for user to override
       # model.train_on_batch.
tensorflow/python/keras/engine/training_test.py

@@ -904,6 +904,23 @@ class TrainingTest(keras_parameterized.TestCase):
     x2 = model.predict(val_a)
     self.assertAllClose(x1, x2, atol=1e-7)

+  def test_weight_deduplication_in_methods(self):
+    inp = keras.layers.Input(shape=(1,))
+    bn = keras.layers.BatchNormalization()
+    d = keras.layers.Dense(1)
+
+    m0 = keras.models.Model(inp, d(bn(inp)))
+    m1 = keras.models.Model(inp, d(bn(inp)))
+
+    x0 = m0(inp)
+    x1 = m1(inp)
+    x = keras.layers.Add()([x0, x1])
+
+    model = keras.models.Model(inp, x)
+    self.assertLen(model.trainable_weights, 4)
+    self.assertLen(model.non_trainable_weights, 2)
+    self.assertLen(model.weights, 6)
+
   @keras_parameterized.run_all_keras_modes
   def test_weight_deduplication(self):
     class WatchingLayer(keras.layers.Layer):
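The expected counts in `test_weight_deduplication_in_methods` follow from the per-layer weights: `BatchNormalization` owns gamma and beta (trainable) plus moving_mean and moving_variance (non-trainable), and `Dense` owns a kernel and a bias. The same two layers are shared across `m0`, `m1`, and the outer model, so without dedup each weight would be counted several times. A quick sanity check of the per-layer counts (a sketch, built outside the test):

    import tensorflow as tf

    bn = tf.keras.layers.BatchNormalization()
    d = tf.keras.layers.Dense(1)
    bn.build((None, 1))
    d.build((None, 1))

    assert len(bn.trainable_weights) == 2      # gamma, beta
    assert len(bn.non_trainable_weights) == 2  # moving_mean, moving_variance
    assert len(d.trainable_weights) == 2       # kernel, bias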
tensorflow/python/keras/premade/wide_deep.py

@@ -102,8 +102,8 @@ class WideDeepModel(training.Model):

   # This does not support gradient scaling and LossScaleOptimizer.
   def _backwards(self, tape, loss):
-    linear_vars = self.linear_model._unique_trainable_weights  # pylint: disable=protected-access
-    dnn_vars = self.dnn_model._unique_trainable_weights  # pylint: disable=protected-access
+    linear_vars = self.linear_model.trainable_weights  # pylint: disable=protected-access
+    dnn_vars = self.dnn_model.trainable_weights  # pylint: disable=protected-access
     linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))
     linear_optimizer, dnn_optimizer = self._get_optimizers()
     linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
@@ -134,11 +134,11 @@ class WideDeepModel(training.Model):
     # Training updates
     updates = []
     linear_updates = linear_optimizer.get_updates(
-        params=self.linear_model._unique_trainable_weights,  # pylint: disable=protected-access
+        params=self.linear_model.trainable_weights,  # pylint: disable=protected-access
         loss=self.total_loss)
     updates += linear_updates
     dnn_updates = dnn_optimizer.get_updates(
-        params=self.dnn_model._unique_trainable_weights,  # pylint: disable=protected-access
+        params=self.dnn_model.trainable_weights,  # pylint: disable=protected-access
         loss=self.total_loss)
     updates += dnn_updates
     # Unconditional updates
tensorflow/python/keras/saving/hdf5_format.py

@@ -75,6 +75,12 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
   # TODO(psv) Add warning when we save models that contain non-serializable
   # entities like metrics added using `add_metric` and losses added using
   # `add_loss.`
+  if len(model.weights) != len(model._undeduplicated_weights):
+    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
+                    'This is usually caused by `Variable`s being shared by '
+                    'Layers in the Model. These `Variable`s will be treated '
+                    'as separate `Variable`s when the Model is restored. To '
+                    'avoid this, please save with `save_format="tf"`.')

   if not isinstance(filepath, h5py.File):
     # If file exists and should not be overwritten.
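`_undeduplicated_weights` is private, so the duplicate check above is internal to the saver. A user-level approximation of the same detection, for anyone who wants to inspect a model before saving to HDF5 (a sketch; it walks the public `model.layers` rather than the internal attribute, so it may differ on models with variables attached outside layers):

    def has_shared_weights(model):
      """Returns True if any Variable appears under more than one layer path."""
      seen_ids, total = set(), 0
      for layer in model.layers:
        for w in layer.weights:
          seen_ids.add(id(w))
          total += 1
      return total != len(seen_ids)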
tensorflow/python/keras/utils/layer_utils.py

@@ -235,7 +235,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
   if hasattr(model, '_collected_trainable_weights'):
     trainable_count = count_params(model._collected_trainable_weights)
   else:
-    trainable_count = count_params(model._unique_trainable_weights)
+    trainable_count = count_params(model.trainable_weights)

   non_trainable_count = count_params(model.non_trainable_weights)