Add an option to choose the I/O Device for saving and loading models for CheckpointManager.
This option enables saving and restoring models to or from filesystems that are accessible only from localhost when using multiple devices. PiperOrigin-RevId: 345770754 Change-Id: I8362fbf1402f2e1cb7fa7290a7401cc873c02a1f
This commit is contained in:
parent
993a74f438
commit
cd39d7a4c8
@ -752,7 +752,7 @@ class CheckpointManager(object):
|
||||
"""Returns the `tf.train.Checkpoint` object."""
|
||||
return self._checkpoint
|
||||
|
||||
def save(self, checkpoint_number=None, check_interval=True):
|
||||
def save(self, checkpoint_number=None, check_interval=True, options=None):
|
||||
"""Creates a new checkpoint and manages it.
|
||||
|
||||
Args:
|
||||
@ -768,6 +768,9 @@ class CheckpointManager(object):
|
||||
larger than `checkpoint_interval`. Otherwise it will always save the
|
||||
checkpoint unless a checkpoint has already been saved for the current
|
||||
step.
|
||||
options: Optional `tf.train.CheckpointOptions` object. This argument only
|
||||
works with TF2 checkpoint objects. For example, options =
|
||||
tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
|
||||
|
||||
Returns:
|
||||
The path to the new checkpoint. It is also recorded in the `checkpoints`
|
||||
@ -809,7 +812,10 @@ class CheckpointManager(object):
|
||||
checkpoint_number = training_util.global_step(
|
||||
sess=session, global_step_tensor=checkpoint_number)
|
||||
prefix = "%s-%d" % (self._prefix, checkpoint_number)
|
||||
save_path = self._checkpoint.write(prefix)
|
||||
if options is None:
|
||||
save_path = self._checkpoint.write(prefix)
|
||||
else:
|
||||
save_path = self._checkpoint.write(prefix, options=options)
|
||||
timestamp = time.time()
|
||||
# If this is an overwritten checkpoint we were previously tracking, delete
|
||||
# and reinsert it to make sure it goes to the end of the queue.
|
||||
|
@ -32,6 +32,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user