Fix bug with loading nested model with trainable/nontrainable weights.
Changed the test to the example from #27769. PiperOrigin-RevId: 254305891
This commit is contained in:
parent
7e5a151438
commit
f42549a91a
@ -272,30 +272,30 @@ def preprocess_weights_for_loading(layer,
|
|||||||
Returns:
|
Returns:
|
||||||
A list of weights values (Numpy arrays).
|
A list of weights values (Numpy arrays).
|
||||||
"""
|
"""
|
||||||
new_weights = []
|
trainable_weights = weights[:len(layer.trainable_weights)]
|
||||||
# trainable weights
|
non_trainable_weights = weights[len(layer.trainable_weights):]
|
||||||
for sublayer in layer.layers:
|
|
||||||
num_weights = len(sublayer.trainable_weights)
|
new_trainable_weights = []
|
||||||
if num_weights > 0:
|
new_non_trainable_weights = []
|
||||||
new_weights.extend(preprocess_weights_for_loading(
|
|
||||||
layer=sublayer,
|
|
||||||
weights=weights[:num_weights],
|
|
||||||
original_keras_version=original_keras_version,
|
|
||||||
original_backend=original_backend))
|
|
||||||
weights = weights[num_weights:]
|
|
||||||
|
|
||||||
# non-trainable weights
|
|
||||||
for sublayer in layer.layers:
|
for sublayer in layer.layers:
|
||||||
num_weights = len([l for l in sublayer.weights
|
num_trainable_weights = len(sublayer.trainable_weights)
|
||||||
if l not in sublayer.trainable_weights])
|
num_non_trainable_weights = len(sublayer.non_trainable_weights)
|
||||||
if num_weights > 0:
|
if sublayer.weights:
|
||||||
new_weights.extend(preprocess_weights_for_loading(
|
preprocessed = preprocess_weights_for_loading(
|
||||||
layer=sublayer,
|
layer=sublayer,
|
||||||
weights=weights[:num_weights],
|
weights=(trainable_weights[:num_trainable_weights] +
|
||||||
|
non_trainable_weights[:num_non_trainable_weights]),
|
||||||
original_keras_version=original_keras_version,
|
original_keras_version=original_keras_version,
|
||||||
original_backend=original_backend))
|
original_backend=original_backend)
|
||||||
weights = weights[num_weights:]
|
new_trainable_weights.extend(preprocessed[:num_trainable_weights])
|
||||||
return new_weights
|
new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])
|
||||||
|
|
||||||
|
trainable_weights = trainable_weights[num_trainable_weights:]
|
||||||
|
non_trainable_weights = non_trainable_weights[
|
||||||
|
num_non_trainable_weights:]
|
||||||
|
|
||||||
|
return new_trainable_weights + new_non_trainable_weights
|
||||||
|
|
||||||
# Convert layers nested in Bidirectional/Model/Sequential.
|
# Convert layers nested in Bidirectional/Model/Sequential.
|
||||||
# Both transformation should be ran for both Keras 1->2 conversion
|
# Both transformation should be ran for both Keras 1->2 conversion
|
||||||
|
@ -265,33 +265,33 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||||||
self.addCleanup(shutil.rmtree, temp_dir)
|
self.addCleanup(shutil.rmtree, temp_dir)
|
||||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||||
|
|
||||||
num_hidden = 5
|
|
||||||
input_dim = 3
|
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
num_classes = 2
|
shape = (None, None, 3)
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
model = keras.models.Sequential()
|
def gen_model():
|
||||||
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
|
|
||||||
model.add(keras.layers.Dense(num_classes))
|
|
||||||
|
|
||||||
nested_model = keras.models.Sequential()
|
def seq_model():
|
||||||
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
|
model = keras.models.Sequential([
|
||||||
nested_model.add(keras.layers.Dense(num_classes))
|
keras.layers.Conv2D(3, 1, input_shape=shape),
|
||||||
model.add(nested_model)
|
keras.layers.BatchNormalization()])
|
||||||
|
return model
|
||||||
|
|
||||||
x = np.random.random((batch_size, input_dim))
|
x = inner_inputs = keras.layers.Input((None, None, 3))
|
||||||
|
x = seq_model()(x)
|
||||||
|
x = seq_model()(x)
|
||||||
|
inner_model = keras.models.Model(inner_inputs, x)
|
||||||
|
|
||||||
|
inputs = keras.layers.Input(shape)
|
||||||
|
return keras.models.Model(inputs, inner_model(inputs))
|
||||||
|
|
||||||
|
model = gen_model()
|
||||||
|
x = np.random.random((batch_size, 1, 1, 3))
|
||||||
ref_y = model.predict(x)
|
ref_y = model.predict(x)
|
||||||
|
|
||||||
model.save_weights(h5_path)
|
model.save_weights(h5_path)
|
||||||
|
|
||||||
model = keras.models.Sequential()
|
model = gen_model()
|
||||||
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
|
|
||||||
model.add(keras.layers.Dense(num_classes))
|
|
||||||
nested_model = keras.models.Sequential()
|
|
||||||
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
|
|
||||||
nested_model.add(keras.layers.Dense(num_classes))
|
|
||||||
model.add(nested_model)
|
|
||||||
model.load_weights(h5_path)
|
model.load_weights(h5_path)
|
||||||
y = model.predict(x)
|
y = model.predict(x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user