Extend tf.train.CheckpointManager functionality:

1. Add a `restore_or_initialize` method to `CheckpointManager` to support initializing from a checkpoint if possible.
2. Support optionally saving checkpoints based on a step interval.

PiperOrigin-RevId: 294816552
Change-Id: I47f4955e75677b02c10b4baddf03d78822dea6ae
parent 8cad6f59e1
commit 73c31e6c97
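Taken together, the two additions let a training program restore, warm-start, and periodically checkpoint through a single manager. A minimal sketch of the intended usage (the directory path and step variable below are placeholders, not part of this commit):

    import tensorflow as tf

    step = tf.Variable(0, dtype=tf.int64)      # training step counter
    checkpoint = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory='/tmp/ckpt',                 # placeholder path
        max_to_keep=3,
        step_counter=step,
        checkpoint_interval=100)               # save at most every 100 steps

    # Restore the latest checkpoint if one exists; otherwise run `init_fn`
    # (not set here) and return None.
    manager.restore_or_initialize()

    for _ in range(1000):
      step.assign_add(1)                       # stand-in for a real train step
      # Returns a save path on checkpoint steps, None otherwise.
      manager.save(check_interval=True)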
@@ -40,6 +40,13 @@ from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
+def _evaluate(tensor):
+  """Returns the numpy value of a tensor."""
+  if context.executing_eagerly():
+    return tensor.numpy()
+  return ops.get_default_session().run(tensor)
+
+
 def _GetCheckpointFilename(save_dir, latest_filename):
   """Returns a filename for storing the CheckpointState.
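The `_evaluate` helper is needed because the manager must read the step counter as a concrete number under both execution modes. Roughly, in user-level terms (an illustrative sketch for a variable `v`, not code from this commit):

    import tensorflow as tf

    v = tf.Variable(3)
    if tf.executing_eagerly():
      value = v.numpy()                                  # eager: read directly
    else:
      value = tf.compat.v1.get_default_session().run(v)  # graph: run via session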
@@ -529,7 +536,10 @@ class CheckpointManager(object):
                directory,
                max_to_keep,
                keep_checkpoint_every_n_hours=None,
-               checkpoint_name="ckpt"):
+               checkpoint_name="ckpt",
+               step_counter=None,
+               checkpoint_interval=None,
+               init_fn=None):
     """Configure a `CheckpointManager` for use in `directory`.
 
     If a `CheckpointManager` was previously used in `directory`, its
@@ -550,6 +560,28 @@ class CheckpointManager(object):
     `CheckpointManager` instantiated in `directory` (subject to its
     `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
 
+    `CheckpointManager` can also be used to initialize the model if
+    there are no checkpoints to restore in `directory`. An example usage is:
+
+    >>> import tempfile
+
+    >>> tmp_dir = tempfile.mkdtemp()
+    >>> checkpoint = tf.train.Checkpoint()
+    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))
+
+    >>> def init_fn():
+    ...   # Partially restore the checkpoint from `init_path`.
+    ...   checkpoint.restore(init_path)
+
+    >>> manager = tf.train.CheckpointManager(
+    ...     checkpoint,
+    ...     directory=os.path.join(tmp_dir, 'ckpt'),
+    ...     max_to_keep=None,
+    ...     init_fn=init_fn)
+    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
+    >>> # checkpoint in `directory`.
+    >>> manager.restore_or_initialize()
+
     Args:
       checkpoint: The `tf.train.Checkpoint` instance to save and manage
         checkpoints for.
@@ -569,6 +601,12 @@ class CheckpointManager(object):
         `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
         default setting of `None` does not preserve any checkpoints in this way.
       checkpoint_name: Custom name for the checkpoint file.
+      step_counter: A `tf.Variable` instance for checking the current step
+        counter value, in case users want to save checkpoints every N steps.
+      checkpoint_interval: An integer, the minimum step interval between two
+        checkpoints.
+      init_fn: Callable. A function to do customized initialization if no
+        checkpoints are in the directory.
 
     Raises:
      ValueError: If `max_to_keep` is not a positive integer.
@@ -584,6 +622,16 @@ class CheckpointManager(object):
     self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
     self._directory = directory
     self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
+    self._init_fn = init_fn
+
+    if checkpoint_interval is not None:
+      if step_counter is None:
+        raise ValueError("`step_counter` should be passed if "
+                         "`checkpoint_interval` is not None.")
+      self._last_checkpoint_step = None
+      self._step_counter = step_counter
+    self._checkpoint_interval = checkpoint_interval
+
     recovered_state = get_checkpoint_state(directory)
     current_clock = time.time()
     self._maybe_delete = collections.OrderedDict()
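Note the fail-fast validation: `checkpoint_interval` is meaningless without a way to read the current step, so constructing a manager without `step_counter` raises immediately. An illustrative sketch (the path is a placeholder):

    import tensorflow as tf

    checkpoint = tf.train.Checkpoint()
    try:
      tf.train.CheckpointManager(
          checkpoint, '/tmp/ckpt', max_to_keep=1, checkpoint_interval=10)
    except ValueError as e:
      # "`step_counter` should be passed if `checkpoint_interval` is not None."
      print(e)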
@@ -619,6 +667,10 @@ class CheckpointManager(object):
   def directory(self):
     return self._directory
 
+  @property
+  def checkpoint_interval(self):
+    return self._checkpoint_interval
+
   @property
   def latest_checkpoint(self):
     """The prefix of the most recent checkpoint in `directory`.
@@ -690,7 +742,12 @@ class CheckpointManager(object):
     """
     return self._checkpoint_prefix
 
-  def save(self, checkpoint_number=None):
+  @property
+  def checkpoint(self):
+    """Returns the `tf.train.Checkpoint` object."""
+    return self._checkpoint
+
+  def save(self, checkpoint_number=None, check_interval=True):
     """Creates a new checkpoint and manages it.
 
     Args:
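The new read-only `checkpoint` property gives code that holds only the manager a path back to the underlying `tf.train.Checkpoint`, which matters because `restore_or_initialize` (added below) does not return a load-status object. An illustrative sketch of running restore assertions through the property:

    # Assumes `manager` is an existing tf.train.CheckpointManager.
    status = manager.checkpoint.restore(manager.latest_checkpoint)
    status.assert_existing_objects_matched()  # assertions need the status object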
@@ -700,11 +757,27 @@ class CheckpointManager(object):
         `checkpoint_number` is provided, `save_counter` is still incremented. A
         user-provided `checkpoint_number` is not incremented even if it is a
         `Variable`.
+      check_interval: An optional boolean. The argument is only effective when
+        `checkpoint_interval` is passed into the manager. If `True`, the manager
+        will only save the checkpoint if the interval between checkpoints is
+        larger than `checkpoint_interval`. Otherwise it will always save the
+        checkpoint unless a checkpoint has already been saved for the current
+        step.
 
     Returns:
       The path to the new checkpoint. It is also recorded in the `checkpoints`
-      and `latest_checkpoint` properties.
+      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
     """
+    if self._checkpoint_interval is not None:
+      current_step = _evaluate(self._step_counter)
+      if self._last_checkpoint_step is not None:
+        if current_step == self._last_checkpoint_step:
+          return None
+        if check_interval and current_step < (
+            self._last_checkpoint_step + self._checkpoint_interval):
+          return None
+      self._last_checkpoint_step = current_step
+
     # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
     # slightly with a custom numbering option.
     if context.executing_eagerly():
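The interval logic makes `save()` safe to call unconditionally every step: it returns a path only when the counter has advanced by at least `checkpoint_interval` since the last write, and `check_interval=False` forces a write unless this exact step was already saved. An illustrative trace under these semantics (directory is a placeholder):

    import tensorflow as tf

    step = tf.Variable(0)
    ckpt = tf.train.Checkpoint(step=step)
    mgr = tf.train.CheckpointManager(
        ckpt, '/tmp/demo', max_to_keep=None,
        step_counter=step, checkpoint_interval=2)

    mgr.save()                       # step 0: first save always writes
    mgr.save()                       # same step: returns None
    step.assign_add(1)
    mgr.save()                       # step 1 < 0 + 2: returns None
    step.assign_add(1)
    mgr.save()                       # step 2: interval reached, writes
    mgr.save(check_interval=False)   # still step 2: returns None
    step.assign_add(1)
    mgr.save(check_interval=False)   # step 3: forced write despite interval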
@@ -749,3 +822,31 @@ class CheckpointManager(object):
     # checkpoints.
     self._record_state()
     return save_path
+
+  def restore_or_initialize(self):
+    """Restore items in `checkpoint` from the latest checkpoint file.
+
+    This method will first try to restore from the most recent checkpoint in
+    `directory`. If no checkpoints exist in `directory`, and `init_fn` is
+    specified, this method will call `init_fn` to do customized
+    initialization. This can be used to support initialization from pretrained
+    models.
+
+    Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return
+    a load status object that users can run assertions on
+    (e.g. `assert_consumed()`). Thus to run assertions, users should directly use
+    the `tf.train.Checkpoint.restore()` method.
+
+    Returns:
+      The restored checkpoint path if the latest checkpoint is found and
+      restored. Otherwise None.
+    """
+    if self._latest_checkpoint is not None:
+      self._checkpoint.restore(self._latest_checkpoint)
+      if self._checkpoint_interval is not None:
+        self._last_checkpoint_step = _evaluate(self._step_counter)
+      return self._latest_checkpoint
+
+    if self._init_fn is not None:
+      self._init_fn()
+    return None
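In practice `init_fn` is the hook for warm-starting from a pretrained model: a checkpoint in `directory` always wins, and `init_fn` runs only when the directory is empty. A sketch of that pattern (the pretrained path, model, and directory are hypothetical):

    import tensorflow as tf

    pretrained_path = '/path/to/pretrained/ckpt-100'   # hypothetical
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    checkpoint = tf.train.Checkpoint(model=model)

    def init_fn():
      # Partially restore pretrained weights; unmatched objects are ignored.
      checkpoint.restore(pretrained_path).expect_partial()

    manager = tf.train.CheckpointManager(
        checkpoint, '/tmp/finetune', max_to_keep=3, init_fn=init_fn)

    if manager.restore_or_initialize() is None:
      print('No checkpoint found; initialized from the pretrained model.')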
@@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase):
     path = manager.save(checkpoint_number=5)
     self.assertEqual(os.path.basename(path), "ckpt-5")
 
+  @test_util.run_in_graph_and_eager_modes
+  def testRestoreOrInitialize(self):
+    directory = self.get_temp_dir()
+
+    # Create a checkpoint for initializing.
+    init_prefix = os.path.join(directory, "init")
+    init_v = variables.Variable(2.0)
+    init_ckpt = util.Checkpoint(v=init_v)
+    self.evaluate(init_v.initializer)
+    init_path = init_ckpt.save(init_prefix)
+
+    # Create the checkpoint manager.
+    ckpt_dir = os.path.join(directory, "ckpt")
+    v = variables.Variable(1.0)
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        ckpt_dir,
+        max_to_keep=None,
+        init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
+    self.evaluate(v.initializer)
+
+    # The first call should invoke `init_fn`.
+    self.assertIsNone(manager.restore_or_initialize())
+    self.assertEqual(2.0, self.evaluate(v))
+
+    # Save a checkpoint; the second call should restore from it.
+    manager.save()
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointInterval(self):
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        self.get_temp_dir(),
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # step_counter: 0, save an initial checkpoint.
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # step_counter: 1, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 2, checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # No checkpoint saved when calling `save` with the same step counter.
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 3, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # Always save the checkpoint.
+    path = manager.save(check_interval=False)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointIntervalWithRestore(self):
+    directory = self.get_temp_dir()
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+
+    # Prepare a checkpoint.
+    checkpoint = util.Checkpoint(v=v)
+    checkpoint.save(os.path.join(directory, "ckpt"))
+
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        directory,
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # Restore from the checkpoint.
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+    # step_counter: 0, no checkpoint saved because it is restored from the
+    # checkpoint with the same step.
+    path = manager.save()
+    self.assertIsNone(path)
+
+
 if __name__ == "__main__":
   test.main()
@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "checkpoint_interval"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "checkpoints"
     mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
+    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "restore_or_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
 }