Extend tf.train.CheckpointManager
functionality:
1. Add a `restore_or_initialize` method into CheckpointManager to support initializing from a checkpoint if possible. 2. Support optionally saving checkpoints based on interval steps. PiperOrigin-RevId: 294816552 Change-Id: I47f4955e75677b02c10b4baddf03d78822dea6ae
This commit is contained in:
parent
8cad6f59e1
commit
73c31e6c97
@ -40,6 +40,13 @@ from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def _evaluate(tensor):
|
||||
"""Returns the numpy value of a tensor."""
|
||||
if context.executing_eagerly():
|
||||
return tensor.numpy()
|
||||
return ops.get_default_session().run(tensor)
|
||||
|
||||
|
||||
def _GetCheckpointFilename(save_dir, latest_filename):
|
||||
"""Returns a filename for storing the CheckpointState.
|
||||
|
||||
@ -529,7 +536,10 @@ class CheckpointManager(object):
|
||||
directory,
|
||||
max_to_keep,
|
||||
keep_checkpoint_every_n_hours=None,
|
||||
checkpoint_name="ckpt"):
|
||||
checkpoint_name="ckpt",
|
||||
step_counter=None,
|
||||
checkpoint_interval=None,
|
||||
init_fn=None):
|
||||
"""Configure a `CheckpointManager` for use in `directory`.
|
||||
|
||||
If a `CheckpointManager` was previously used in `directory`, its
|
||||
@ -550,6 +560,28 @@ class CheckpointManager(object):
|
||||
`CheckpointManager` instantiated in `directory` (subject to its
|
||||
`max_to_keep` and `keep_checkpoint_every_n_hours` settings).
|
||||
|
||||
`CheckpointManager` can be also used for initializing the model if
|
||||
there is no checkpoints for restoring in `directory`. An example usage is:
|
||||
|
||||
>>> import tempfile
|
||||
|
||||
>>> tmp_dir = tempfile.mkdtemp()
|
||||
>>> checkpoint = tf.train.Checkpoint()
|
||||
>>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))
|
||||
|
||||
>>> def init_fn():
|
||||
... # Partially restore the checkpoint from `init_path`.
|
||||
... checkpoint.restore(init_path)
|
||||
|
||||
>>> manager = tf.train.CheckpointManager(
|
||||
... checkpoint,
|
||||
... directory=os.path.join(tmp_dir, 'ckpt'),
|
||||
... max_to_keep=None,
|
||||
... init_fn=init_fn)
|
||||
>>> # `restore_or_initialize` will call `init_fn` if there is no existing
|
||||
>>> # checkpoint in `directory`.
|
||||
>>> manager.restore_or_initialize()
|
||||
|
||||
Args:
|
||||
checkpoint: The `tf.train.Checkpoint` instance to save and manage
|
||||
checkpoints for.
|
||||
@ -569,6 +601,12 @@ class CheckpointManager(object):
|
||||
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
|
||||
default setting of `None` does not preserve any checkpoints in this way.
|
||||
checkpoint_name: Custom name for the checkpoint file.
|
||||
step_counter: A `tf.Variable` instance for checking the current step
|
||||
counter value, in case users want to save checkpoints every N steps.
|
||||
checkpoint_interval: An integer, indicates the minimum step interval
|
||||
between two checkpoints.
|
||||
init_fn: Callable. A function to do customized intialization if no
|
||||
checkpoints are in the directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If `max_to_keep` is not a positive integer.
|
||||
@ -584,6 +622,16 @@ class CheckpointManager(object):
|
||||
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
|
||||
self._directory = directory
|
||||
self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
|
||||
self._init_fn = init_fn
|
||||
|
||||
if checkpoint_interval is not None:
|
||||
if step_counter is None:
|
||||
raise ValueError("`step_counter` should be passed if "
|
||||
"`checkpoint_interval` is not None.")
|
||||
self._last_checkpoint_step = None
|
||||
self._step_counter = step_counter
|
||||
self._checkpoint_interval = checkpoint_interval
|
||||
|
||||
recovered_state = get_checkpoint_state(directory)
|
||||
current_clock = time.time()
|
||||
self._maybe_delete = collections.OrderedDict()
|
||||
@ -619,6 +667,10 @@ class CheckpointManager(object):
|
||||
def directory(self):
|
||||
return self._directory
|
||||
|
||||
@property
|
||||
def checkpoint_interval(self):
|
||||
return self._checkpoint_interval
|
||||
|
||||
@property
|
||||
def latest_checkpoint(self):
|
||||
"""The prefix of the most recent checkpoint in `directory`.
|
||||
@ -690,7 +742,12 @@ class CheckpointManager(object):
|
||||
"""
|
||||
return self._checkpoint_prefix
|
||||
|
||||
def save(self, checkpoint_number=None):
|
||||
@property
|
||||
def checkpoint(self):
|
||||
"""Returns the `tf.train.Checkpoint` object."""
|
||||
return self._checkpoint
|
||||
|
||||
def save(self, checkpoint_number=None, check_interval=True):
|
||||
"""Creates a new checkpoint and manages it.
|
||||
|
||||
Args:
|
||||
@ -700,11 +757,27 @@ class CheckpointManager(object):
|
||||
`checkpoint_number` is provided, `save_counter` is still incremented. A
|
||||
user-provided `checkpoint_number` is not incremented even if it is a
|
||||
`Variable`.
|
||||
check_interval: An optional boolean. The argument is only effective when
|
||||
`checkpoint_interval` is passed into the manager. If `True`, the manager
|
||||
will only save the checkpoint if the interval between checkpoints is
|
||||
larger than `checkpoint_interval`. Otherwise it will always save the
|
||||
checkpoint unless a checkpoint has already been saved for the current
|
||||
step.
|
||||
|
||||
Returns:
|
||||
The path to the new checkpoint. It is also recorded in the `checkpoints`
|
||||
and `latest_checkpoint` properties.
|
||||
and `latest_checkpoint` properties. `None` if no checkpoint is saved.
|
||||
"""
|
||||
if self._checkpoint_interval is not None:
|
||||
current_step = _evaluate(self._step_counter)
|
||||
if self._last_checkpoint_step is not None:
|
||||
if current_step == self._last_checkpoint_step:
|
||||
return None
|
||||
if check_interval and current_step < (
|
||||
self._last_checkpoint_step + self._checkpoint_interval):
|
||||
return None
|
||||
self._last_checkpoint_step = current_step
|
||||
|
||||
# Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
|
||||
# slightly with a custom numbering option.
|
||||
if context.executing_eagerly():
|
||||
@ -749,3 +822,31 @@ class CheckpointManager(object):
|
||||
# checkpoints.
|
||||
self._record_state()
|
||||
return save_path
|
||||
|
||||
def restore_or_initialize(self):
|
||||
"""Restore items in `checkpoint` from the latest checkpoint file.
|
||||
|
||||
This method will first try to restore from the most recent checkpoint in
|
||||
`directory`. If no checkpoints exist in `directory`, and `init_fn` is
|
||||
specified, this method will call `init_fn` to do customized
|
||||
initialization. This can be used to support initialization from pretrained
|
||||
models.
|
||||
|
||||
Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return
|
||||
a load status object that users can run assertions on
|
||||
(e.g. assert_consumed()). Thus to run assertions, users should directly use
|
||||
`tf.train.Checkpoint.restore()` method.
|
||||
|
||||
Returns:
|
||||
The restored checkpoint path if the lastest checkpoint is found and
|
||||
restored. Otherwise None.
|
||||
"""
|
||||
if self._latest_checkpoint is not None:
|
||||
self._checkpoint.restore(self._latest_checkpoint)
|
||||
if self._checkpoint_interval is not None:
|
||||
self._last_checkpoint_step = _evaluate(self._step_counter)
|
||||
return self._latest_checkpoint
|
||||
|
||||
if self._init_fn is not None:
|
||||
self._init_fn()
|
||||
return None
|
||||
|
@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase):
|
||||
path = manager.save(checkpoint_number=5)
|
||||
self.assertEqual(os.path.basename(path), "ckpt-5")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testRestoreOrInitialize(self):
|
||||
directory = self.get_temp_dir()
|
||||
|
||||
# Create a checkpoint for initializing.
|
||||
init_prefix = os.path.join(directory, "init")
|
||||
init_v = variables.Variable(2.0)
|
||||
init_ckpt = util.Checkpoint(v=init_v)
|
||||
self.evaluate(init_v.initializer)
|
||||
init_path = init_ckpt.save(init_prefix)
|
||||
|
||||
# Create the checkpoint manager.
|
||||
ckpt_dir = os.path.join(directory, "ckpt")
|
||||
v = variables.Variable(1.0)
|
||||
checkpoint = util.Checkpoint(v=v)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
checkpoint,
|
||||
ckpt_dir,
|
||||
max_to_keep=None,
|
||||
init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
# First call should call `init_fn`.
|
||||
self.assertIsNone(manager.restore_or_initialize())
|
||||
self.assertEqual(2.0, self.evaluate(v))
|
||||
|
||||
# Save a checkpoint and second call should restore from the checkpoints.
|
||||
manager.save()
|
||||
self.assertIsNotNone(manager.restore_or_initialize())
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCheckpointInterval(self):
|
||||
v = variables.Variable(1.0)
|
||||
step_counter = variables.Variable(0)
|
||||
self.evaluate([v.initializer, step_counter.initializer])
|
||||
checkpoint = util.Checkpoint(v=v)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
checkpoint,
|
||||
self.get_temp_dir(),
|
||||
max_to_keep=None,
|
||||
step_counter=step_counter,
|
||||
checkpoint_interval=2)
|
||||
|
||||
# step_counter: 0, save an initial checkpoint.
|
||||
path = manager.save(check_interval=True)
|
||||
self.assertTrue(checkpoint_management.checkpoint_exists(path))
|
||||
|
||||
# step_counter: 1, no checkpoint saved.
|
||||
self.evaluate(step_counter.assign_add(1))
|
||||
path = manager.save(check_interval=True)
|
||||
self.assertIsNone(path)
|
||||
|
||||
# step_counter: 2, checkpoint saved.
|
||||
self.evaluate(step_counter.assign_add(1))
|
||||
path = manager.save(check_interval=True)
|
||||
self.assertTrue(checkpoint_management.checkpoint_exists(path))
|
||||
|
||||
# no checkpoint saved when calling `save` with the same step counter.
|
||||
path = manager.save(check_interval=True)
|
||||
self.assertIsNone(path)
|
||||
|
||||
# step_counter: 3, no checkpoint saved.
|
||||
self.evaluate(step_counter.assign_add(1))
|
||||
path = manager.save(check_interval=True)
|
||||
self.assertIsNone(path)
|
||||
|
||||
# Always save the checkpoint.
|
||||
path = manager.save(check_interval=False)
|
||||
self.assertTrue(checkpoint_management.checkpoint_exists(path))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCheckpointIntervalWithRestore(self):
|
||||
directory = self.get_temp_dir()
|
||||
v = variables.Variable(1.0)
|
||||
step_counter = variables.Variable(0)
|
||||
self.evaluate([v.initializer, step_counter.initializer])
|
||||
|
||||
# Prepare a checkpoint.
|
||||
checkpoint = util.Checkpoint(v=v)
|
||||
checkpoint.save(os.path.join(directory, "ckpt"))
|
||||
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
checkpoint,
|
||||
directory,
|
||||
max_to_keep=None,
|
||||
step_counter=step_counter,
|
||||
checkpoint_interval=2)
|
||||
|
||||
# Restore from the checkpoint.
|
||||
self.assertIsNotNone(manager.restore_or_initialize())
|
||||
|
||||
# step_counter: 0, no checkpoint saved because it is restored from the
|
||||
# checkpoint with the same step.
|
||||
path = manager.save()
|
||||
self.assertIsNone(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "checkpoint_interval"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "checkpoints"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -16,10 +24,14 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "restore_or_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "checkpoint_interval"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "checkpoints"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -16,10 +24,14 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "restore_or_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user