Merge pull request #32257 from omalleyt12/cherrypicks_20IEW

[r2.0-CherryPick] Deduplicate Keras weights
commit 5a580681ad
Goldie Gadde 2019-09-06 14:14:22 -07:00, committed by GitHub
8 changed files with 57 additions and 26 deletions
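
Context for this cherry-pick: when a layer instance (and therefore its `Variable`s) is shared by several layers or models, the `weights`, `trainable_weights`, and `non_trainable_weights` properties could return the same `Variable` more than once. This change deduplicates those lists by object identity while preserving order. A minimal sketch of the symptom, assuming the TF 2.0 Keras functional API (it mirrors the new test added in this commit):

import tensorflow as tf  # assumes TF 2.0

inp = tf.keras.layers.Input(shape=(1,))
shared = tf.keras.layers.Dense(1)  # owns one kernel and one bias

# The same layer instance is reachable through two sub-models.
m0 = tf.keras.models.Model(inp, shared(inp))
m1 = tf.keras.models.Model(inp, shared(inp))
out = tf.keras.layers.Add()([m0(inp), m1(inp)])
model = tf.keras.models.Model(inp, out)

# Before this change the shared kernel and bias could each appear once per
# path; after it, every `Variable` is listed exactly once.
print(len(model.weights))  # expected: 2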

View File

@@ -905,7 +905,7 @@ class Layer(module.Module):
  def trainable_weights(self):
    if self.trainable:
      nested = self._gather_children_attribute('trainable_weights')
-      return self._trainable_weights + nested
+      return self._dedup_weights(self._trainable_weights + nested)
    else:
      return []
@@ -913,10 +913,12 @@ class Layer(module.Module):
  def non_trainable_weights(self):
    if self.trainable:
      nested = self._gather_children_attribute('non_trainable_weights')
-      return self._non_trainable_weights + nested
+      non_trainable_weights = self._non_trainable_weights + nested
    else:
      nested = self._gather_children_attribute('weights')
-      return self._trainable_weights + self._non_trainable_weights + nested
+      non_trainable_weights = (
+          self._trainable_weights + self._non_trainable_weights + nested)
+    return self._dedup_weights(non_trainable_weights)

  @property
  def weights(self):
@@ -2452,14 +2454,13 @@ class Layer(module.Module):
            serialization_cache))
    return fns

-  @property
-  def _unique_trainable_weights(self):
-    """Dedupe trainable weights while maintaining order as much as possible."""
-    trainable_weights = self.trainable_weights
+  def _dedup_weights(self, weights):
+    """Dedupe weights while maintaining order as much as possible."""
    output, seen_weights = [], object_identity.ObjectIdentitySet()
-    for w in trainable_weights:
+    for w in weights:
      if w not in seen_weights:
        output.append(w)
        # Track the Variable's identity to avoid __eq__ issues.
        seen_weights.add(w)
    return output
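
The helper is an order-preserving, identity-based dedup. `Variable.__eq__` in TF 2.x can return an elementwise comparison rather than a plain bool (the "__eq__ issues" the comment mentions), so membership is tracked with `object_identity.ObjectIdentitySet`, which hashes and compares elements by identity instead of value. A standalone sketch of the same pattern using only the standard library (hypothetical helper, not the TF implementation):

def dedup_preserving_order(items):
  """Drop duplicates by object identity, keeping first-occurrence order."""
  output, seen_ids = [], set()
  for item in items:
    # Compare identities, not values: value equality may not be a bool here.
    if id(item) not in seen_ids:
      output.append(item)
      seen_ids.add(id(item))
  return output

`ObjectIdentitySet` is preferable in TF itself because it keeps references to the wrapped objects, so ids cannot be recycled while the set is alive.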

View File

@@ -472,6 +472,11 @@ class Network(base_layer.Layer):
    Returns:
      A list of variables.
    """
+    return self._dedup_weights(self._undeduplicated_weights)
+
+  @property
+  def _undeduplicated_weights(self):
+    """Returns the undeduplicated list of all layer variables/weights."""
    self._assert_weights_created()
    weights = []
    for layer in self._layers:
@@ -535,18 +540,21 @@ class Network(base_layer.Layer):
  @property
  def trainable_weights(self):
    self._assert_weights_created()
-    return trackable_layer_utils.gather_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._trainable_weights))

  @property
  def non_trainable_weights(self):
    self._assert_weights_created()
-    return trackable_layer_utils.gather_non_trainable_weights(
-        trainable=self.trainable,
-        sub_layers=self._layers,
-        extra_variables=self._non_trainable_weights + self._trainable_weights)
+    return self._dedup_weights(
+        trackable_layer_utils.gather_non_trainable_weights(
+            trainable=self.trainable,
+            sub_layers=self._layers,
+            extra_variables=self._non_trainable_weights +
+            self._trainable_weights))

  @property
  def input_spec(self):
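
Note the division of labor this file sets up: the public properties now return deduplicated lists, while the raw gather survives as `_undeduplicated_weights` so that callers can still detect that sharing produced duplicates; the HDF5 saver below compares the two lengths to decide whether to warn. A hedged usage sketch, reusing the shared-layer `model` from the earlier example (`_undeduplicated_weights` is a private attribute, shown only to illustrate the design):

assert len(model.weights) <= len(model._undeduplicated_weights)
if len(model.weights) != len(model._undeduplicated_weights):
  print('some `Variable`s are shared between layers')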

View File

@@ -386,7 +386,7 @@ class Model(network.Network):
      self.predict_function = None

      # Collected trainable weights, sorted in topological order.
-      self._collected_trainable_weights = self._unique_trainable_weights
+      self._collected_trainable_weights = self.trainable_weights

      # Validate all variables were correctly created in distribution scope.
      if self._distribution_strategy and not self._compile_distribution:
@@ -1535,7 +1535,7 @@ class Model(network.Network):
    # Set metric attributes on model.
    self._set_metric_attributes()

-    self._collected_trainable_weights = self._unique_trainable_weights
+    self._collected_trainable_weights = self.trainable_weights

  def _update_sample_weight_modes(self, sample_weights=None):
    """Updates sample weight modes based on training/eval inputs.
@@ -2046,8 +2046,7 @@ class Model(network.Network):
    if not hasattr(self, '_collected_trainable_weights'):
      return

-    if (len(self._unique_trainable_weights) !=
-        len(self._collected_trainable_weights)):
+    if len(self.trainable_weights) != len(self._collected_trainable_weights):
      logging.log_first_n(
          logging.WARN, 'Discrepancy between trainable weights and collected'
          ' trainable weights, did you set `model.trainable`'
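
`_collected_trainable_weights` is snapshotted at compile time, so the length comparison above fires when the deduplicated `trainable_weights` list has changed size since `compile`, most commonly because `model.trainable` or a layer's `trainable` flag was flipped afterwards. A minimal way to trigger the warning, assuming `x` and `y` are suitable numpy arrays:

model.compile(optimizer='sgd', loss='mse')  # snapshot taken here
model.trainable = False                     # trainable_weights shrinks
model.fit(x, y)  # logs the 'Discrepancy between trainable weights...' warning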

View File

@@ -258,7 +258,7 @@ def _process_single_batch(model,
  else:
    scaled_total_loss = total_loss
  if training:
-    trainable_weights = model._unique_trainable_weights
+    trainable_weights = model.trainable_weights
    if trainable_weights:
      # TODO(tanzheny) b/132690565: Provide mechanism for user to override
      # model.train_on_batch.

View File

@@ -904,6 +904,23 @@ class TrainingTest(keras_parameterized.TestCase):
    x2 = model.predict(val_a)
    self.assertAllClose(x1, x2, atol=1e-7)

+  def test_weight_deduplication_in_methods(self):
+    inp = keras.layers.Input(shape=(1,))
+    bn = keras.layers.BatchNormalization()
+    d = keras.layers.Dense(1)
+
+    m0 = keras.models.Model(inp, d(bn(inp)))
+    m1 = keras.models.Model(inp, d(bn(inp)))
+
+    x0 = m0(inp)
+    x1 = m1(inp)
+    x = keras.layers.Add()([x0, x1])
+
+    model = keras.models.Model(inp, x)
+    self.assertLen(model.trainable_weights, 4)
+    self.assertLen(model.non_trainable_weights, 2)
+    self.assertLen(model.weights, 6)
+
  @keras_parameterized.run_all_keras_modes
  def test_weight_deduplication(self):

    class WatchingLayer(keras.layers.Layer):
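
The expected counts in the new test follow from what each shared layer owns: `BatchNormalization` contributes two trainable variables (`gamma`, `beta`) and two non-trainable ones (`moving_mean`, `moving_variance`), while `Dense(1)` contributes a trainable kernel and bias. Even though `bn` and `d` are reachable through `m0`, `m1`, and `model` itself, each variable is now counted once: 4 trainable + 2 non-trainable = 6 total. A quick way to verify the per-layer inventory (a sketch, assuming TF 2.0):

bn = tf.keras.layers.BatchNormalization()
bn.build((None, 1))  # creates the four variables
print([v.name for v in bn.trainable_weights])      # gamma, beta
print([v.name for v in bn.non_trainable_weights])  # moving_mean, moving_variance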

View File

@@ -102,8 +102,8 @@ class WideDeepModel(training.Model):
  # This does not support gradient scaling and LossScaleOptimizer.
  def _backwards(self, tape, loss):
-    linear_vars = self.linear_model._unique_trainable_weights  # pylint: disable=protected-access
-    dnn_vars = self.dnn_model._unique_trainable_weights  # pylint: disable=protected-access
+    linear_vars = self.linear_model.trainable_weights  # pylint: disable=protected-access
+    dnn_vars = self.dnn_model.trainable_weights  # pylint: disable=protected-access
    linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))
    linear_optimizer, dnn_optimizer = self._get_optimizers()
    linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
@@ -134,11 +134,11 @@ class WideDeepModel(training.Model):
      # Training updates
      updates = []
      linear_updates = linear_optimizer.get_updates(
-          params=self.linear_model._unique_trainable_weights,  # pylint: disable=protected-access
+          params=self.linear_model.trainable_weights,  # pylint: disable=protected-access
          loss=self.total_loss)
      updates += linear_updates
      dnn_updates = dnn_optimizer.get_updates(
-          params=self.dnn_model._unique_trainable_weights,  # pylint: disable=protected-access
+          params=self.dnn_model.trainable_weights,  # pylint: disable=protected-access
          loss=self.total_loss)
      updates += dnn_updates

      # Unconditional updates
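
`_backwards` leans on `tf.GradientTape.gradient` accepting a nested structure of sources and returning gradients in the same structure, so a single tape pass yields separate gradient lists for the linear and DNN towers. A self-contained sketch of that pattern (hypothetical variables, not the WideDeepModel internals):

import tensorflow as tf

w_linear = [tf.Variable(1.0), tf.Variable(2.0)]
w_dnn = [tf.Variable(3.0)]

with tf.GradientTape() as tape:
  loss = w_linear[0] * w_linear[1] + w_dnn[0] ** 2

# One call, nested sources: gradients come back in the same nesting.
linear_grads, dnn_grads = tape.gradient(loss, (w_linear, w_dnn))
print([g.numpy() for g in linear_grads])  # [2.0, 1.0]
print([g.numpy() for g in dnn_grads])     # [6.0]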

View File

@@ -75,6 +75,12 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
  # TODO(psv) Add warning when we save models that contain non-serializable
  # entities like metrics added using `add_metric` and losses added using
  # `add_loss.`
+  if len(model.weights) != len(model._undeduplicated_weights):
+    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
+                    'This is usually caused by `Variable`s being shared by '
+                    'Layers in the Model. These `Variable`s will be treated '
+                    'as separate `Variable`s when the Model is restored. To '
+                    'avoid this, please save with `save_format="tf"`.')

  if not isinstance(filepath, h5py.File):
    # If file exists and should not be overwritten.
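
The warning exists because the HDF5 format stores weights grouped per layer, so a `Variable` shared by two layers is written, and later restored, once per layer, silently breaking the sharing; the object-based TF checkpoint format preserves it. A hedged sketch of when it fires, reusing the shared-layer `model` from earlier (file paths are illustrative):

model.save('model.h5')  # HDF5: logs the duplicated-`Variable`s warning
model.save('model_tf', save_format='tf')  # object-based format keeps sharing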

View File

@@ -235,7 +235,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
-    trainable_count = count_params(model._unique_trainable_weights)
+    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)
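
`count_params` totals the scalar elements across a weight list, so the summary's trainable and non-trainable counts now reflect the deduplicated lists (or the compile-time snapshot when one exists). Roughly equivalent to the following sketch (not the internal implementation):

import numpy as np

def count_params(weights):
  # Total number of scalars across all weight arrays.
  return int(sum(np.prod(w.shape.as_list()) for w in weights))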