diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 5a00e8f56c8..df68d43bb66 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -40,6 +40,13 @@ from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
+def _evaluate(tensor):
+  """Returns the numpy value of a tensor."""
+  if context.executing_eagerly():
+    return tensor.numpy()
+  return ops.get_default_session().run(tensor)
+
+
 def _GetCheckpointFilename(save_dir, latest_filename):
   """Returns a filename for storing the CheckpointState.
 
@@ -529,7 +536,10 @@ class CheckpointManager(object):
                directory,
                max_to_keep,
                keep_checkpoint_every_n_hours=None,
-               checkpoint_name="ckpt"):
+               checkpoint_name="ckpt",
+               step_counter=None,
+               checkpoint_interval=None,
+               init_fn=None):
     """Configure a `CheckpointManager` for use in `directory`.
 
     If a `CheckpointManager` was previously used in `directory`, its
@@ -550,6 +560,28 @@
     `CheckpointManager` instantiated in `directory` (subject to its
     `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
 
+    `CheckpointManager` can also be used to initialize the model if
+    there are no checkpoints to restore in `directory`. An example usage is:
+
+    >>> import tempfile
+
+    >>> tmp_dir = tempfile.mkdtemp()
+    >>> checkpoint = tf.train.Checkpoint()
+    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))
+
+    >>> def init_fn():
+    ...   # Partially restore the checkpoint from `init_path`.
+    ...   checkpoint.restore(init_path)
+
+    >>> manager = tf.train.CheckpointManager(
+    ...     checkpoint,
+    ...     directory=os.path.join(tmp_dir, 'ckpt'),
+    ...     max_to_keep=None,
+    ...     init_fn=init_fn)
+    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
+    >>> # checkpoint in `directory`.
+    >>> manager.restore_or_initialize()
+
     Args:
       checkpoint: The `tf.train.Checkpoint` instance to save and manage
         checkpoints for.
@@ -569,6 +601,12 @@
         `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
         default setting of `None` does not preserve any checkpoints in this way.
       checkpoint_name: Custom name for the checkpoint file.
+      step_counter: A `tf.Variable` instance for checking the current step
+        counter value, in case users want to save checkpoints every N steps.
+      checkpoint_interval: An integer, indicating the minimum step interval
+        between two checkpoints.
+      init_fn: Callable. A function to do customized initialization if no
+        checkpoints are in the directory.
 
     Raises:
       ValueError: If `max_to_keep` is not a positive integer.
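For context, the two new constructor arguments are meant to be driven by the
training loop. A minimal usage sketch (not part of the patch itself; the
directory path and the loop are hypothetical, and eager execution is assumed):

    import tensorflow as tf

    step = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory='/tmp/ckpt',  # hypothetical path
        max_to_keep=5,
        step_counter=step,
        checkpoint_interval=100)

    for _ in range(1000):
      # ... run one training step, then advance the counter ...
      step.assign_add(1)
      # Returns None (a no-op) unless at least `checkpoint_interval` steps
      # have elapsed since the last saved checkpoint.
      manager.save()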
@@ -584,6 +622,16 @@ class CheckpointManager(object):
     self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
     self._directory = directory
     self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
+    self._init_fn = init_fn
+
+    if checkpoint_interval is not None:
+      if step_counter is None:
+        raise ValueError("`step_counter` should be passed if "
+                         "`checkpoint_interval` is not None.")
+      self._last_checkpoint_step = None
+      self._step_counter = step_counter
+    self._checkpoint_interval = checkpoint_interval
+
     recovered_state = get_checkpoint_state(directory)
     current_clock = time.time()
     self._maybe_delete = collections.OrderedDict()
@@ -619,6 +667,10 @@ class CheckpointManager(object):
   def directory(self):
     return self._directory
 
+  @property
+  def checkpoint_interval(self):
+    return self._checkpoint_interval
+
   @property
   def latest_checkpoint(self):
     """The prefix of the most recent checkpoint in `directory`.
@@ -690,7 +742,12 @@
     """
     return self._checkpoint_prefix
 
-  def save(self, checkpoint_number=None):
+  @property
+  def checkpoint(self):
+    """Returns the `tf.train.Checkpoint` object."""
+    return self._checkpoint
+
+  def save(self, checkpoint_number=None, check_interval=True):
     """Creates a new checkpoint and manages it.
 
     Args:
@@ -700,11 +757,27 @@
         `checkpoint_number` is provided, `save_counter` is still incremented. A
         user-provided `checkpoint_number` is not incremented even if it is a
         `Variable`.
+      check_interval: An optional boolean. The argument is only effective when
+        `checkpoint_interval` is passed into the manager. If `True`, the
+        manager will only save the checkpoint if the interval between two
+        checkpoints is at least `checkpoint_interval` steps. Otherwise it will
+        always save the checkpoint unless a checkpoint has already been saved
+        for the current step.
 
     Returns:
       The path to the new checkpoint. It is also recorded in the `checkpoints`
-      and `latest_checkpoint` properties.
+      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
     """
+    if self._checkpoint_interval is not None:
+      current_step = _evaluate(self._step_counter)
+      if self._last_checkpoint_step is not None:
+        if current_step == self._last_checkpoint_step:
+          return None
+        if check_interval and current_step < (
+            self._last_checkpoint_step + self._checkpoint_interval):
+          return None
+      self._last_checkpoint_step = current_step
+
     # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
     # slightly with a custom numbering option.
     if context.executing_eagerly():
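The gating added to `save` above is plain bookkeeping over three values. A
framework-free sketch of the same decision logic (`should_save` is a
hypothetical helper for illustration, not TF API):

    def should_save(current_step, last_step, interval, check_interval=True):
      # Mirrors the checks at the top of `CheckpointManager.save`.
      if last_step is None:          # Nothing saved yet: always save.
        return True
      if current_step == last_step:  # Already saved at this step: skip.
        return False
      if check_interval and current_step < last_step + interval:
        return False                 # Interval not yet elapsed.
      return True

    assert should_save(0, None, 2)                     # initial save
    assert not should_save(1, 0, 2)                    # too soon
    assert should_save(2, 0, 2)                        # interval reached
    assert not should_save(2, 2, 2)                    # same step twice
    assert should_save(3, 2, 2, check_interval=False)  # forced save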
+ """ + if self._latest_checkpoint is not None: + self._checkpoint.restore(self._latest_checkpoint) + if self._checkpoint_interval is not None: + self._last_checkpoint_step = _evaluate(self._step_counter) + return self._latest_checkpoint + + if self._init_fn is not None: + self._init_fn() + return None diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py index 4c40945cb15..34666e32ab6 100644 --- a/tensorflow/python/training/checkpoint_management_test.py +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase): path = manager.save(checkpoint_number=5) self.assertEqual(os.path.basename(path), "ckpt-5") + @test_util.run_in_graph_and_eager_modes + def testRestoreOrInitialize(self): + directory = self.get_temp_dir() + + # Create a checkpoint for initializing. + init_prefix = os.path.join(directory, "init") + init_v = variables.Variable(2.0) + init_ckpt = util.Checkpoint(v=init_v) + self.evaluate(init_v.initializer) + init_path = init_ckpt.save(init_prefix) + + # Create the checkpoint manager. + ckpt_dir = os.path.join(directory, "ckpt") + v = variables.Variable(1.0) + checkpoint = util.Checkpoint(v=v) + manager = checkpoint_management.CheckpointManager( + checkpoint, + ckpt_dir, + max_to_keep=None, + init_fn=lambda: checkpoint.restore(init_path).run_restore_ops()) + self.evaluate(v.initializer) + + # First call should call `init_fn`. + self.assertIsNone(manager.restore_or_initialize()) + self.assertEqual(2.0, self.evaluate(v)) + + # Save a checkpoint and second call should restore from the checkpoints. + manager.save() + self.assertIsNotNone(manager.restore_or_initialize()) + + @test_util.run_in_graph_and_eager_modes + def testCheckpointInterval(self): + v = variables.Variable(1.0) + step_counter = variables.Variable(0) + self.evaluate([v.initializer, step_counter.initializer]) + checkpoint = util.Checkpoint(v=v) + manager = checkpoint_management.CheckpointManager( + checkpoint, + self.get_temp_dir(), + max_to_keep=None, + step_counter=step_counter, + checkpoint_interval=2) + + # step_counter: 0, save an initial checkpoint. + path = manager.save(check_interval=True) + self.assertTrue(checkpoint_management.checkpoint_exists(path)) + + # step_counter: 1, no checkpoint saved. + self.evaluate(step_counter.assign_add(1)) + path = manager.save(check_interval=True) + self.assertIsNone(path) + + # step_counter: 2, checkpoint saved. + self.evaluate(step_counter.assign_add(1)) + path = manager.save(check_interval=True) + self.assertTrue(checkpoint_management.checkpoint_exists(path)) + + # no checkpoint saved when calling `save` with the same step counter. + path = manager.save(check_interval=True) + self.assertIsNone(path) + + # step_counter: 3, no checkpoint saved. + self.evaluate(step_counter.assign_add(1)) + path = manager.save(check_interval=True) + self.assertIsNone(path) + + # Always save the checkpoint. + path = manager.save(check_interval=False) + self.assertTrue(checkpoint_management.checkpoint_exists(path)) + + @test_util.run_in_graph_and_eager_modes + def testCheckpointIntervalWithRestore(self): + directory = self.get_temp_dir() + v = variables.Variable(1.0) + step_counter = variables.Variable(0) + self.evaluate([v.initializer, step_counter.initializer]) + + # Prepare a checkpoint. 
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 4c40945cb15..34666e32ab6 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase):
     path = manager.save(checkpoint_number=5)
     self.assertEqual(os.path.basename(path), "ckpt-5")
 
+  @test_util.run_in_graph_and_eager_modes
+  def testRestoreOrInitialize(self):
+    directory = self.get_temp_dir()
+
+    # Create a checkpoint for initializing.
+    init_prefix = os.path.join(directory, "init")
+    init_v = variables.Variable(2.0)
+    init_ckpt = util.Checkpoint(v=init_v)
+    self.evaluate(init_v.initializer)
+    init_path = init_ckpt.save(init_prefix)
+
+    # Create the checkpoint manager.
+    ckpt_dir = os.path.join(directory, "ckpt")
+    v = variables.Variable(1.0)
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        ckpt_dir,
+        max_to_keep=None,
+        init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
+    self.evaluate(v.initializer)
+
+    # The first call should invoke `init_fn`.
+    self.assertIsNone(manager.restore_or_initialize())
+    self.assertEqual(2.0, self.evaluate(v))
+
+    # Save a checkpoint; the second call should restore from that checkpoint.
+    manager.save()
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointInterval(self):
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        self.get_temp_dir(),
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # step_counter: 0, save an initial checkpoint.
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # step_counter: 1, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 2, checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # No checkpoint is saved when calling `save` again at the same step.
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 3, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # Always save the checkpoint.
+    path = manager.save(check_interval=False)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointIntervalWithRestore(self):
+    directory = self.get_temp_dir()
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+
+    # Prepare a checkpoint.
+    checkpoint = util.Checkpoint(v=v)
+    checkpoint.save(os.path.join(directory, "ckpt"))
+
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        directory,
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # Restore from the checkpoint.
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+    # step_counter: 0, no checkpoint saved because the manager was restored
+    # from a checkpoint with the same step.
+    path = manager.save()
+    self.assertIsNone(path)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
index 86e25d86d53..6ab4e1c085a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "checkpoint_interval"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "checkpoints"
     mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
+    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "restore_or_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
index 86e25d86d53..6ab4e1c085a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "checkpoint_interval"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "checkpoints"
     mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
+    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "restore_or_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
 }