Extend tf.train.CheckpointManager functionality:

1. Add a `restore_or_initialize` method to `CheckpointManager` to support initializing from a checkpoint if possible.
2. Support optionally saving checkpoints only at a given minimum step interval (see the usage sketch after the commit metadata below).

PiperOrigin-RevId: 294816552
Change-Id: I47f4955e75677b02c10b4baddf03d78822dea6ae
Author: Ruoxin Sang, 2020-02-12 19:34:34 -08:00 (committed by TensorFlower Gardener)
Commit: 73c31e6c97 (parent: 8cad6f59e1)
4 changed files with 228 additions and 7 deletions
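Taken together, the two changes compose as in this minimal sketch (assuming a TensorFlow build that includes this change; the variable names and the no-op `init_fn` are illustrative, not from the commit):

import os
import tempfile

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)  # illustrative step counter
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt,
    directory=os.path.join(tempfile.mkdtemp(), "ckpts"),
    max_to_keep=3,
    step_counter=step,        # new argument in this change
    checkpoint_interval=100,  # new argument in this change
    init_fn=lambda: None)     # new argument in this change; no-op here

# Restores the latest checkpoint in `directory` if one exists; otherwise
# calls `init_fn` and returns None.
manager.restore_or_initialize()

for _ in range(250):
  step.assign_add(1)
  # After the first save, a path is returned only when at least
  # `checkpoint_interval` steps have passed since the last save;
  # otherwise `save` returns None.
  manager.save(check_interval=True)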

File: tensorflow/python/training/checkpoint_management.py

@@ -40,6 +40,13 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
def _evaluate(tensor):
"""Returns the numpy value of a tensor."""
if context.executing_eagerly():
return tensor.numpy()
return ops.get_default_session().run(tensor)
def _GetCheckpointFilename(save_dir, latest_filename):
"""Returns a filename for storing the CheckpointState.
@@ -529,7 +536,10 @@ class CheckpointManager(object):
directory,
max_to_keep,
keep_checkpoint_every_n_hours=None,
checkpoint_name="ckpt"):
checkpoint_name="ckpt",
step_counter=None,
checkpoint_interval=None,
init_fn=None):
"""Configure a `CheckpointManager` for use in `directory`.
If a `CheckpointManager` was previously used in `directory`, its
@@ -550,6 +560,28 @@ class CheckpointManager(object):
`CheckpointManager` instantiated in `directory` (subject to its
`max_to_keep` and `keep_checkpoint_every_n_hours` settings).
`CheckpointManager` can also be used to initialize the model if there
are no checkpoints to restore in `directory`. An example usage is:
>>> import os
>>> import tempfile
>>> tmp_dir = tempfile.mkdtemp()
>>> checkpoint = tf.train.Checkpoint()
>>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))
>>> def init_fn():
... # Partially restore the checkpoint from `init_path`.
... checkpoint.restore(init_path)
>>> manager = tf.train.CheckpointManager(
... checkpoint,
... directory=os.path.join(tmp_dir, 'ckpt'),
... max_to_keep=None,
... init_fn=init_fn)
>>> # `restore_or_initialize` will call `init_fn` if there is no existing
>>> # checkpoint in `directory`.
>>> manager.restore_or_initialize()
Args:
checkpoint: The `tf.train.Checkpoint` instance to save and manage
checkpoints for.
@@ -569,6 +601,12 @@ class CheckpointManager(object):
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
default setting of `None` does not preserve any checkpoints in this way.
checkpoint_name: Custom name for the checkpoint file.
step_counter: A `tf.Variable` instance holding the current step counter
value, used when users want to save checkpoints every N steps (required
if `checkpoint_interval` is set).
checkpoint_interval: An integer, the minimum number of steps between two
saved checkpoints.
init_fn: Callable. A function to do customized initialization if no
checkpoints are in the directory.
Raises:
ValueError: If `max_to_keep` is not a positive integer.
@@ -584,6 +622,16 @@ class CheckpointManager(object):
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self._directory = directory
self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
self._init_fn = init_fn
if checkpoint_interval is not None:
if step_counter is None:
raise ValueError("`step_counter` should be passed if "
"`checkpoint_interval` is not None.")
self._last_checkpoint_step = None
self._step_counter = step_counter
self._checkpoint_interval = checkpoint_interval
recovered_state = get_checkpoint_state(directory)
current_clock = time.time()
self._maybe_delete = collections.OrderedDict()
@@ -619,6 +667,10 @@ class CheckpointManager(object):
def directory(self):
return self._directory
@property
def checkpoint_interval(self):
return self._checkpoint_interval
@property
def latest_checkpoint(self):
"""The prefix of the most recent checkpoint in `directory`.
@@ -690,7 +742,12 @@ class CheckpointManager(object):
"""
return self._checkpoint_prefix
def save(self, checkpoint_number=None):
@property
def checkpoint(self):
"""Returns the `tf.train.Checkpoint` object."""
return self._checkpoint
def save(self, checkpoint_number=None, check_interval=True):
"""Creates a new checkpoint and manages it.
Args:
@@ -700,11 +757,27 @@ class CheckpointManager(object):
`checkpoint_number` is provided, `save_counter` is still incremented. A
user-provided `checkpoint_number` is not incremented even if it is a
`Variable`.
check_interval: An optional boolean. Only effective when
`checkpoint_interval` is passed into the manager. If `True`, the manager
will only save the checkpoint if at least `checkpoint_interval` steps
have passed since the last saved checkpoint. Otherwise it will always
save the checkpoint unless a checkpoint has already been saved for the
current step.
Returns:
The path to the new checkpoint. It is also recorded in the `checkpoints`
and `latest_checkpoint` properties.
and `latest_checkpoint` properties. `None` if no checkpoint is saved.
"""
if self._checkpoint_interval is not None:
current_step = _evaluate(self._step_counter)
if self._last_checkpoint_step is not None:
if current_step == self._last_checkpoint_step:
return None
if check_interval and current_step < (
self._last_checkpoint_step + self._checkpoint_interval):
return None
self._last_checkpoint_step = current_step
# Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
# slightly with a custom numbering option.
if context.executing_eagerly():
@@ -749,3 +822,31 @@ class CheckpointManager(object):
# checkpoints.
self._record_state()
return save_path
def restore_or_initialize(self):
"""Restore items in `checkpoint` from the latest checkpoint file.
This method will first try to restore from the most recent checkpoint in
`directory`. If no checkpoints exist in `directory`, and `init_fn` is
specified, this method will call `init_fn` to do customized
initialization. This can be used to support initialization from pretrained
models.
Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't
return a load status object that users can run assertions on
(e.g. `assert_consumed()`). To run such assertions, use the
`tf.train.Checkpoint.restore()` method directly.
Returns:
The restored checkpoint path if the latest checkpoint is found and
restored. Otherwise None.
"""
if self._latest_checkpoint is not None:
self._checkpoint.restore(self._latest_checkpoint)
if self._checkpoint_interval is not None:
self._last_checkpoint_step = _evaluate(self._step_counter)
return self._latest_checkpoint
if self._init_fn is not None:
self._init_fn()
return None
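Because `restore_or_initialize` returns a path (or `None`) rather than a load status object, callers who want `assert_consumed`-style checks can restore through the `Checkpoint` directly, as the docstring suggests. A minimal sketch, assuming eager execution and a `manager` like the one sketched earlier:

latest = manager.latest_checkpoint
if latest is not None:
  # `manager.checkpoint` is the new read-only property added above.
  status = manager.checkpoint.restore(latest)
  status.assert_consumed()  # raises AssertionError if any value went unmatched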

File: tensorflow/python/training/checkpoint_management_test.py

@@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase):
path = manager.save(checkpoint_number=5)
self.assertEqual(os.path.basename(path), "ckpt-5")
@test_util.run_in_graph_and_eager_modes
def testRestoreOrInitialize(self):
directory = self.get_temp_dir()
# Create a checkpoint for initializing.
init_prefix = os.path.join(directory, "init")
init_v = variables.Variable(2.0)
init_ckpt = util.Checkpoint(v=init_v)
self.evaluate(init_v.initializer)
init_path = init_ckpt.save(init_prefix)
# Create the checkpoint manager.
ckpt_dir = os.path.join(directory, "ckpt")
v = variables.Variable(1.0)
checkpoint = util.Checkpoint(v=v)
manager = checkpoint_management.CheckpointManager(
checkpoint,
ckpt_dir,
max_to_keep=None,
init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
self.evaluate(v.initializer)
# First call should call `init_fn`.
self.assertIsNone(manager.restore_or_initialize())
self.assertEqual(2.0, self.evaluate(v))
# Save a checkpoint; the second call should restore from it.
manager.save()
self.assertIsNotNone(manager.restore_or_initialize())
@test_util.run_in_graph_and_eager_modes
def testCheckpointInterval(self):
v = variables.Variable(1.0)
step_counter = variables.Variable(0)
self.evaluate([v.initializer, step_counter.initializer])
checkpoint = util.Checkpoint(v=v)
manager = checkpoint_management.CheckpointManager(
checkpoint,
self.get_temp_dir(),
max_to_keep=None,
step_counter=step_counter,
checkpoint_interval=2)
# step_counter: 0, save an initial checkpoint.
path = manager.save(check_interval=True)
self.assertTrue(checkpoint_management.checkpoint_exists(path))
# step_counter: 1, no checkpoint saved.
self.evaluate(step_counter.assign_add(1))
path = manager.save(check_interval=True)
self.assertIsNone(path)
# step_counter: 2, checkpoint saved.
self.evaluate(step_counter.assign_add(1))
path = manager.save(check_interval=True)
self.assertTrue(checkpoint_management.checkpoint_exists(path))
# No checkpoint saved when calling `save` again at the same step.
path = manager.save(check_interval=True)
self.assertIsNone(path)
# step_counter: 3, no checkpoint saved.
self.evaluate(step_counter.assign_add(1))
path = manager.save(check_interval=True)
self.assertIsNone(path)
# Always save the checkpoint.
path = manager.save(check_interval=False)
self.assertTrue(checkpoint_management.checkpoint_exists(path))
@test_util.run_in_graph_and_eager_modes
def testCheckpointIntervalWithRestore(self):
directory = self.get_temp_dir()
v = variables.Variable(1.0)
step_counter = variables.Variable(0)
self.evaluate([v.initializer, step_counter.initializer])
# Prepare a checkpoint.
checkpoint = util.Checkpoint(v=v)
checkpoint.save(os.path.join(directory, "ckpt"))
manager = checkpoint_management.CheckpointManager(
checkpoint,
directory,
max_to_keep=None,
step_counter=step_counter,
checkpoint_interval=2)
# Restore from the checkpoint.
self.assertIsNotNone(manager.restore_or_initialize())
# step_counter: 0, no checkpoint saved because the manager was restored
# from a checkpoint at the same step.
path = manager.save()
self.assertIsNone(path)
if __name__ == "__main__":
test.main()
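The gating behavior these tests exercise can be restated as a standalone predicate; this is an illustrative paraphrase of the logic inside `save`, not code from the commit:

def should_save(current_step, last_step, interval, check_interval=True):
  """Mirrors the interval gating in `CheckpointManager.save` (illustrative)."""
  if last_step is None:
    return True   # the first save always goes through
  if current_step == last_step:
    return False  # never save twice for the same step
  if check_interval and current_step < last_step + interval:
    return False  # too few steps since the last save
  return True

assert should_save(0, None, 2)                     # initial checkpoint
assert not should_save(1, 0, 2)                    # interval not reached
assert should_save(2, 0, 2)                        # interval reached
assert not should_save(2, 2, 2)                    # same step, skipped
assert not should_save(3, 2, 2)                    # interval not reached
assert should_save(3, 2, 2, check_interval=False)  # forced save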

File: tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt

@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
tf_class {
is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
is_instance: "<type \'object\'>"
member {
name: "checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "checkpoint_interval"
mtype: "<type \'property\'>"
}
member {
name: "checkpoints"
mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "restore_or_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "save"
argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
}

File: tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt

@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
tf_class {
is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
is_instance: "<type \'object\'>"
member {
name: "checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "checkpoint_interval"
mtype: "<type \'property\'>"
}
member {
name: "checkpoints"
mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "restore_or_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "save"
argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
}