Fix checkpointing tests.

1. Add calls to `run_restore_ops` to ensure that the restore ops are executed when using checkpoints in graph mode.
2. Ensure that layer orders are the same between the saved model and restored model. Otherwise there will be a race condition when restoring the checkpoint values.

PiperOrigin-RevId: 288719809
Change-Id: I6f87a481e9fe3ea1e8ebc667cfea61dcb0716236
This commit is contained in:
Katherine Wu 2020-01-08 10:03:59 -08:00 committed by TensorFlower Gardener
parent 18845a4659
commit 2b3296441b
2 changed files with 11 additions and 22 deletions

View File

@ -1099,7 +1099,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._weight_loading_test_template(SubclassedModel)
def _new_layer_weight_loading_test_template(
self, first_model_fn, second_model_fn, restore_init_fn):
self, first_model_fn, second_model_fn):
with self.cached_session() as session:
model = first_model_fn()
temp_dir = self.get_temp_dir()
@ -1122,12 +1122,13 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.addCleanup(shutil.rmtree, temp_dir)
second_model = second_model_fn()
second_model.load_weights(prefix)
status = second_model.load_weights(prefix)
second_model(x)
self.evaluate(restore_init_fn(second_model))
status.run_restore_ops()
second_model.save_weights(prefix)
# Check that the second model's checkpoint loads into the original model
model.load_weights(prefix)
status = model.load_weights(prefix)
status.run_restore_ops(session)
y = self.evaluate(model(x))
self.assertAllClose(ref_y, y)
@ -1144,12 +1145,9 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
y = keras.layers.Dense(1, name='second')(x)
b = keras.layers.Dense(3, name='secondjr')(y)
return keras.models.Model(a, b)
def _restore_init_fn(restore_model):
return [v.initializer for v in restore_model.layers[-1].variables]
self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model,
_restore_init_fn)
_save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model_added_no_weight_layer(self):
@ -1161,16 +1159,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def _restore_graph_model():
a = keras.layers.Input(shape=(2,))
x = keras.layers.Dense(3, name='first')(a)
y = keras.layers.Dropout(rate=0.1)(x)
b = keras.layers.Dense(1, name='second')(y)
return keras.models.Model(a, b)
def _restore_init_fn(restore_model):
del restore_model # unused
return []
b = keras.layers.Dense(1, name='second')(x)
y = keras.layers.Dropout(rate=0.1)(b)
return keras.models.Model(a, y)
self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model,
_restore_init_fn)
_save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes
def test_weight_loading_subclassed_model_added_layer(self):
@ -1186,12 +1180,8 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def call(self, a):
return self.b_layer(self.y_layer(self.x_layer(a)))
def _restore_init_fn(restore_model):
return [v.initializer for v in restore_model.y_layer.variables]
self._new_layer_weight_loading_test_template(
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
SubclassedModel, SubclassedModelRestore)
@test_util.run_in_graph_and_eager_modes
def test_incompatible_checkpoint(self):

View File

@ -352,7 +352,6 @@ class CheckpointPosition(object):
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
del saveables_cache[self.trackable]
break
if saveable is None:
# If there was no cached SaveableObject, we should check if the Python
# object has the attribute.