Add an option to choose the I/O Device for saving and loading models for CheckpointManager.

This option enables saving models to, and restoring them from, filesystems
that are accessible only from localhost when using multiple devices.

PiperOrigin-RevId: 345770754
Change-Id: I8362fbf1402f2e1cb7fa7290a7401cc873c02a1f
This commit is contained in:
A. Unique TensorFlower 2020-12-04 15:37:28 -08:00 committed by TensorFlower Gardener
parent 993a74f438
commit cd39d7a4c8
3 changed files with 10 additions and 4 deletions

View File

@ -752,7 +752,7 @@ class CheckpointManager(object):
"""Returns the `tf.train.Checkpoint` object."""
return self._checkpoint
def save(self, checkpoint_number=None, check_interval=True):
def save(self, checkpoint_number=None, check_interval=True, options=None):
"""Creates a new checkpoint and manages it.
Args:
@ -768,6 +768,9 @@ class CheckpointManager(object):
larger than `checkpoint_interval`. Otherwise it will always save the
checkpoint unless a checkpoint has already been saved for the current
step.
options: Optional `tf.train.CheckpointOptions` object. This argument only
works with TF2 checkpoint objects. For example, options =
tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
Returns:
The path to the new checkpoint. It is also recorded in the `checkpoints`
@ -809,7 +812,10 @@ class CheckpointManager(object):
checkpoint_number = training_util.global_step(
sess=session, global_step_tensor=checkpoint_number)
prefix = "%s-%d" % (self._prefix, checkpoint_number)
if options is None:
save_path = self._checkpoint.write(prefix)
else:
save_path = self._checkpoint.write(prefix, options=options)
timestamp = time.time()
# If this is an overwritten checkpoint we were previously tracking, delete
# and reinsert it to make sure it goes to the end of the queue.

View File

@ -32,6 +32,6 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
}

View File

@ -32,6 +32,6 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
}