Fix bug with loading nested model with trainable/nontrainable weights.
Changed the test to the example from #27769. PiperOrigin-RevId: 254305891
This commit is contained in:
parent
7e5a151438
commit
f42549a91a
@ -272,30 +272,30 @@ def preprocess_weights_for_loading(layer,
|
||||
Returns:
|
||||
A list of weights values (Numpy arrays).
|
||||
"""
|
||||
new_weights = []
|
||||
# trainable weights
|
||||
for sublayer in layer.layers:
|
||||
num_weights = len(sublayer.trainable_weights)
|
||||
if num_weights > 0:
|
||||
new_weights.extend(preprocess_weights_for_loading(
|
||||
layer=sublayer,
|
||||
weights=weights[:num_weights],
|
||||
original_keras_version=original_keras_version,
|
||||
original_backend=original_backend))
|
||||
weights = weights[num_weights:]
|
||||
trainable_weights = weights[:len(layer.trainable_weights)]
|
||||
non_trainable_weights = weights[len(layer.trainable_weights):]
|
||||
|
||||
new_trainable_weights = []
|
||||
new_non_trainable_weights = []
|
||||
|
||||
# non-trainable weights
|
||||
for sublayer in layer.layers:
|
||||
num_weights = len([l for l in sublayer.weights
|
||||
if l not in sublayer.trainable_weights])
|
||||
if num_weights > 0:
|
||||
new_weights.extend(preprocess_weights_for_loading(
|
||||
num_trainable_weights = len(sublayer.trainable_weights)
|
||||
num_non_trainable_weights = len(sublayer.non_trainable_weights)
|
||||
if sublayer.weights:
|
||||
preprocessed = preprocess_weights_for_loading(
|
||||
layer=sublayer,
|
||||
weights=weights[:num_weights],
|
||||
weights=(trainable_weights[:num_trainable_weights] +
|
||||
non_trainable_weights[:num_non_trainable_weights]),
|
||||
original_keras_version=original_keras_version,
|
||||
original_backend=original_backend))
|
||||
weights = weights[num_weights:]
|
||||
return new_weights
|
||||
original_backend=original_backend)
|
||||
new_trainable_weights.extend(preprocessed[:num_trainable_weights])
|
||||
new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])
|
||||
|
||||
trainable_weights = trainable_weights[num_trainable_weights:]
|
||||
non_trainable_weights = non_trainable_weights[
|
||||
num_non_trainable_weights:]
|
||||
|
||||
return new_trainable_weights + new_non_trainable_weights
|
||||
|
||||
# Convert layers nested in Bidirectional/Model/Sequential.
|
||||
# Both transformation should be ran for both Keras 1->2 conversion
|
||||
|
@ -265,33 +265,33 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
|
||||
num_hidden = 5
|
||||
input_dim = 3
|
||||
batch_size = 5
|
||||
num_classes = 2
|
||||
shape = (None, None, 3)
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
|
||||
model.add(keras.layers.Dense(num_classes))
|
||||
def gen_model():
|
||||
|
||||
nested_model = keras.models.Sequential()
|
||||
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
|
||||
nested_model.add(keras.layers.Dense(num_classes))
|
||||
model.add(nested_model)
|
||||
def seq_model():
|
||||
model = keras.models.Sequential([
|
||||
keras.layers.Conv2D(3, 1, input_shape=shape),
|
||||
keras.layers.BatchNormalization()])
|
||||
return model
|
||||
|
||||
x = np.random.random((batch_size, input_dim))
|
||||
x = inner_inputs = keras.layers.Input((None, None, 3))
|
||||
x = seq_model()(x)
|
||||
x = seq_model()(x)
|
||||
inner_model = keras.models.Model(inner_inputs, x)
|
||||
|
||||
inputs = keras.layers.Input(shape)
|
||||
return keras.models.Model(inputs, inner_model(inputs))
|
||||
|
||||
model = gen_model()
|
||||
x = np.random.random((batch_size, 1, 1, 3))
|
||||
ref_y = model.predict(x)
|
||||
|
||||
model.save_weights(h5_path)
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
|
||||
model.add(keras.layers.Dense(num_classes))
|
||||
nested_model = keras.models.Sequential()
|
||||
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
|
||||
nested_model.add(keras.layers.Dense(num_classes))
|
||||
model.add(nested_model)
|
||||
model = gen_model()
|
||||
model.load_weights(h5_path)
|
||||
y = model.predict(x)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user