From 73c31e6c9783cc02c5b21e0e3c6d814a53410ff4 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang <rxsang@google.com>
Date: Wed, 12 Feb 2020 19:34:34 -0800
Subject: [PATCH] Extend `tf.train.CheckpointManager` functionality: 1. Add a
 `restore_or_initialize` method to `CheckpointManager` that restores from an
 existing checkpoint when possible. 2. Support optionally saving checkpoints
 only at a given step interval.

PiperOrigin-RevId: 294816552
Change-Id: I47f4955e75677b02c10b4baddf03d78822dea6ae
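
A minimal usage sketch of both additions (illustrative values; assumes a
TensorFlow build with this patch applied):

    import tempfile
    import tensorflow as tf

    step = tf.Variable(0, dtype=tf.int64)       # training step counter
    checkpoint = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory=tempfile.mkdtemp(),
        max_to_keep=3,
        step_counter=step,          # required when checkpoint_interval is set
        checkpoint_interval=1000,   # save at most once per 1000 steps
        init_fn=lambda: print('no checkpoint found, ran custom init'))

    # Restores the latest checkpoint if one exists; otherwise calls `init_fn`.
    manager.restore_or_initialize()

    for _ in range(2500):
        step.assign_add(1)
        # Returns None unless at least `checkpoint_interval` steps have
        # passed since the last save; pass check_interval=False to force.
        manager.save()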
---
 .../python/training/checkpoint_management.py  | 107 +++++++++++++++++-
 .../training/checkpoint_management_test.py    |  96 ++++++++++++++++
 ...tensorflow.train.-checkpoint-manager.pbtxt |  16 ++-
 ...tensorflow.train.-checkpoint-manager.pbtxt |  16 ++-
 4 files changed, 228 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 5a00e8f56c8..df68d43bb66 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -40,6 +40,13 @@ from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
+def _evaluate(tensor):
+  """Returns the numpy value of a tensor."""
+  if context.executing_eagerly():
+    return tensor.numpy()
+  return ops.get_default_session().run(tensor)
+
+
 def _GetCheckpointFilename(save_dir, latest_filename):
   """Returns a filename for storing the CheckpointState.
 
@@ -529,7 +536,10 @@ class CheckpointManager(object):
                directory,
                max_to_keep,
                keep_checkpoint_every_n_hours=None,
-               checkpoint_name="ckpt"):
+               checkpoint_name="ckpt",
+               step_counter=None,
+               checkpoint_interval=None,
+               init_fn=None):
     """Configure a `CheckpointManager` for use in `directory`.
 
     If a `CheckpointManager` was previously used in `directory`, its
@@ -550,6 +560,28 @@ class CheckpointManager(object):
     `CheckpointManager` instantiated in `directory` (subject to its
     `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
 
+    `CheckpointManager` can also be used to initialize the model if there
+    are no checkpoints to restore in `directory`. An example usage is:
+
+    >>> import os, tempfile
+
+    >>> tmp_dir = tempfile.mkdtemp()
+    >>> checkpoint = tf.train.Checkpoint()
+    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))
+
+    >>> def init_fn():
+    ...   # Partially restore the checkpoint from `init_path`.
+    ...   checkpoint.restore(init_path)
+
+    >>> manager = tf.train.CheckpointManager(
+    ...     checkpoint,
+    ...     directory=os.path.join(tmp_dir, 'ckpt'),
+    ...     max_to_keep=None,
+    ...     init_fn=init_fn)
+    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
+    >>> # checkpoint in `directory`.
+    >>> manager.restore_or_initialize()
+
     Args:
       checkpoint: The `tf.train.Checkpoint` instance to save and manage
         checkpoints for.
@@ -569,6 +601,12 @@ class CheckpointManager(object):
         `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
         default setting of `None` does not preserve any checkpoints in this way.
       checkpoint_name: Custom name for the checkpoint file.
+      step_counter: A `tf.Variable` instance holding the current step counter
+        value, used when saving checkpoints every N steps.
+      checkpoint_interval: An integer, the minimum number of steps between
+        two saved checkpoints.
+      init_fn: Callable. A function to do customized initialization if no
+        checkpoints are found in the directory.
 
     Raises:
       ValueError: If `max_to_keep` is not a positive integer.
@@ -584,6 +622,16 @@ class CheckpointManager(object):
     self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
     self._directory = directory
     self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
+    self._init_fn = init_fn
+
+    if checkpoint_interval is not None:
+      if step_counter is None:
+        raise ValueError("`step_counter` should be passed if "
+                         "`checkpoint_interval` is not None.")
+      self._last_checkpoint_step = None
+      self._step_counter = step_counter
+    self._checkpoint_interval = checkpoint_interval
+
     recovered_state = get_checkpoint_state(directory)
     current_clock = time.time()
     self._maybe_delete = collections.OrderedDict()
@@ -619,6 +667,10 @@ class CheckpointManager(object):
   def directory(self):
     return self._directory
 
+  @property
+  def checkpoint_interval(self):
+    return self._checkpoint_interval
+
   @property
   def latest_checkpoint(self):
     """The prefix of the most recent checkpoint in `directory`.
@@ -690,7 +742,12 @@ class CheckpointManager(object):
     """
     return self._checkpoint_prefix
 
-  def save(self, checkpoint_number=None):
+  @property
+  def checkpoint(self):
+    """Returns the `tf.train.Checkpoint` object."""
+    return self._checkpoint
+
+  def save(self, checkpoint_number=None, check_interval=True):
     """Creates a new checkpoint and manages it.
 
     Args:
@@ -700,11 +757,27 @@ class CheckpointManager(object):
         `checkpoint_number` is provided, `save_counter` is still incremented. A
         user-provided `checkpoint_number` is not incremented even if it is a
         `Variable`.
+      check_interval: An optional boolean. The argument is only effective
+        when `checkpoint_interval` is passed into the manager. If `True`, the
+        manager will only save the checkpoint if at least
+        `checkpoint_interval` steps have passed since the last checkpoint.
+        Otherwise it will always save the checkpoint unless a checkpoint has
+        already been saved for the current step.
 
     Returns:
       The path to the new checkpoint. It is also recorded in the `checkpoints`
-      and `latest_checkpoint` properties.
+      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
     """
+    if self._checkpoint_interval is not None:
+      current_step = _evaluate(self._step_counter)
+      if self._last_checkpoint_step is not None:
+        if current_step == self._last_checkpoint_step:
+          return None
+        if check_interval and current_step < (
+            self._last_checkpoint_step + self._checkpoint_interval):
+          return None
+      self._last_checkpoint_step = current_step
+
     # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
     # slightly with a custom numbering option.
     if context.executing_eagerly():
@@ -749,3 +822,31 @@ class CheckpointManager(object):
     # checkpoints.
     self._record_state()
     return save_path
+
+  def restore_or_initialize(self):
+    """Restore items in `checkpoint` from the latest checkpoint file.
+
+    This method will first try to restore from the most recent checkpoint in
+    `directory`. If no checkpoints exist in `directory` and `init_fn` is
+    specified, this method will call `init_fn` to do customized
+    initialization. This can be used to support initialization from pretrained
+    models.
+
+    Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't
+    return a load status object that users can run assertions on (e.g.
+    `assert_consumed()`). Thus to run assertions, users should use the
+    `tf.train.Checkpoint.restore()` method directly.
+
+    Returns:
+      The restored checkpoint path if the latest checkpoint is found and
+      restored. Otherwise None.
+    """
+    if self._latest_checkpoint is not None:
+      self._checkpoint.restore(self._latest_checkpoint)
+      if self._checkpoint_interval is not None:
+        self._last_checkpoint_step = _evaluate(self._step_counter)
+      return self._latest_checkpoint
+
+    if self._init_fn is not None:
+      self._init_fn()
+    return None
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 4c40945cb15..34666e32ab6 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -571,6 +571,102 @@ class CheckpointManagerTest(test.TestCase):
     path = manager.save(checkpoint_number=5)
     self.assertEqual(os.path.basename(path), "ckpt-5")
 
+  @test_util.run_in_graph_and_eager_modes
+  def testRestoreOrInitialize(self):
+    directory = self.get_temp_dir()
+
+    # Create a checkpoint for initializing.
+    init_prefix = os.path.join(directory, "init")
+    init_v = variables.Variable(2.0)
+    init_ckpt = util.Checkpoint(v=init_v)
+    self.evaluate(init_v.initializer)
+    init_path = init_ckpt.save(init_prefix)
+
+    # Create the checkpoint manager.
+    ckpt_dir = os.path.join(directory, "ckpt")
+    v = variables.Variable(1.0)
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        ckpt_dir,
+        max_to_keep=None,
+        init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
+    self.evaluate(v.initializer)
+
+    # First call should call `init_fn`.
+    self.assertIsNone(manager.restore_or_initialize())
+    self.assertEqual(2.0, self.evaluate(v))
+
+    # Save a checkpoint, so the second call restores from the checkpoint.
+    manager.save()
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointInterval(self):
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+    checkpoint = util.Checkpoint(v=v)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        self.get_temp_dir(),
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # step_counter: 0, save an initial checkpoint.
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # step_counter: 1, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 2, checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+    # No checkpoint saved when calling `save` with the same step counter.
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # step_counter: 3, no checkpoint saved.
+    self.evaluate(step_counter.assign_add(1))
+    path = manager.save(check_interval=True)
+    self.assertIsNone(path)
+
+    # Always save the checkpoint.
+    path = manager.save(check_interval=False)
+    self.assertTrue(checkpoint_management.checkpoint_exists(path))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpointIntervalWithRestore(self):
+    directory = self.get_temp_dir()
+    v = variables.Variable(1.0)
+    step_counter = variables.Variable(0)
+    self.evaluate([v.initializer, step_counter.initializer])
+
+    # Prepare a checkpoint.
+    checkpoint = util.Checkpoint(v=v)
+    checkpoint.save(os.path.join(directory, "ckpt"))
+
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint,
+        directory,
+        max_to_keep=None,
+        step_counter=step_counter,
+        checkpoint_interval=2)
+
+    # Restore from the checkpoint.
+    self.assertIsNotNone(manager.restore_or_initialize())
+
+    # step_counter: 0, no checkpoint saved because the manager was restored
+    # from a checkpoint at the same step.
+    path = manager.save()
+    self.assertIsNone(path)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
index 86e25d86d53..6ab4e1c085a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt
@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "checkpoint_interval"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "checkpoints"
     mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
+    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "restore_or_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
index 86e25d86d53..6ab4e1c085a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt
@@ -2,6 +2,14 @@ path: "tensorflow.train.CheckpointManager"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.checkpoint_management.CheckpointManager\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "checkpoint_interval"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "checkpoints"
     mtype: "<type \'property\'>"
@@ -16,10 +24,14 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\'], "
+    argspec: "args=[\'self\', \'checkpoint\', \'directory\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'checkpoint_name\', \'step_counter\', \'checkpoint_interval\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'ckpt\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "restore_or_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'checkpoint_number\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
 }