Fix checkpointing tests.
1. Add calls to `run_restore_ops` to ensure that the restore ops are executed when using checkpoints in graph mode. 2. Ensure that layer orders are the same between the saved model and restored model. Otherwise there will be a race condition when restoring the checkpoint values. PiperOrigin-RevId: 288719809 Change-Id: I6f87a481e9fe3ea1e8ebc667cfea61dcb0716236
This commit is contained in:
parent
18845a4659
commit
2b3296441b
@ -1099,7 +1099,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
self._weight_loading_test_template(SubclassedModel)
|
||||
|
||||
def _new_layer_weight_loading_test_template(
|
||||
self, first_model_fn, second_model_fn, restore_init_fn):
|
||||
self, first_model_fn, second_model_fn):
|
||||
with self.cached_session() as session:
|
||||
model = first_model_fn()
|
||||
temp_dir = self.get_temp_dir()
|
||||
@ -1122,12 +1122,13 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
|
||||
second_model = second_model_fn()
|
||||
second_model.load_weights(prefix)
|
||||
status = second_model.load_weights(prefix)
|
||||
second_model(x)
|
||||
self.evaluate(restore_init_fn(second_model))
|
||||
status.run_restore_ops()
|
||||
second_model.save_weights(prefix)
|
||||
# Check that the second model's checkpoint loads into the original model
|
||||
model.load_weights(prefix)
|
||||
status = model.load_weights(prefix)
|
||||
status.run_restore_ops(session)
|
||||
y = self.evaluate(model(x))
|
||||
self.assertAllClose(ref_y, y)
|
||||
|
||||
@ -1144,12 +1145,9 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
y = keras.layers.Dense(1, name='second')(x)
|
||||
b = keras.layers.Dense(3, name='secondjr')(y)
|
||||
return keras.models.Model(a, b)
|
||||
def _restore_init_fn(restore_model):
|
||||
return [v.initializer for v in restore_model.layers[-1].variables]
|
||||
|
||||
self._new_layer_weight_loading_test_template(
|
||||
_save_graph_model, _restore_graph_model,
|
||||
_restore_init_fn)
|
||||
_save_graph_model, _restore_graph_model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_weight_loading_graph_model_added_no_weight_layer(self):
|
||||
@ -1161,16 +1159,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
def _restore_graph_model():
|
||||
a = keras.layers.Input(shape=(2,))
|
||||
x = keras.layers.Dense(3, name='first')(a)
|
||||
y = keras.layers.Dropout(rate=0.1)(x)
|
||||
b = keras.layers.Dense(1, name='second')(y)
|
||||
return keras.models.Model(a, b)
|
||||
def _restore_init_fn(restore_model):
|
||||
del restore_model # unused
|
||||
return []
|
||||
b = keras.layers.Dense(1, name='second')(x)
|
||||
y = keras.layers.Dropout(rate=0.1)(b)
|
||||
return keras.models.Model(a, y)
|
||||
|
||||
self._new_layer_weight_loading_test_template(
|
||||
_save_graph_model, _restore_graph_model,
|
||||
_restore_init_fn)
|
||||
_save_graph_model, _restore_graph_model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_weight_loading_subclassed_model_added_layer(self):
|
||||
@ -1186,12 +1180,8 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
def call(self, a):
|
||||
return self.b_layer(self.y_layer(self.x_layer(a)))
|
||||
|
||||
def _restore_init_fn(restore_model):
|
||||
return [v.initializer for v in restore_model.y_layer.variables]
|
||||
|
||||
self._new_layer_weight_loading_test_template(
|
||||
SubclassedModel, SubclassedModelRestore,
|
||||
_restore_init_fn)
|
||||
SubclassedModel, SubclassedModelRestore)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_incompatible_checkpoint(self):
|
||||
|
@ -352,7 +352,6 @@ class CheckpointPosition(object):
|
||||
if serialized_tensor.checkpoint_key not in saveable.name:
|
||||
saveable = None
|
||||
del saveables_cache[self.trackable]
|
||||
break
|
||||
if saveable is None:
|
||||
# If there was no cached SaveableObject, we should check if the Python
|
||||
# object has the attribute.
|
||||
|
Loading…
x
Reference in New Issue
Block a user