Add an option to choose the I/O Device for saving and loading models for CheckpointManager.
This option enables saving and restoring models to or from filesystems that are only accessible from localhost when using multiple devices.

PiperOrigin-RevId: 345770754
Change-Id: I8362fbf1402f2e1cb7fa7290a7401cc873c02a1f
This commit is contained in:
parent
993a74f438
commit
cd39d7a4c8
@ -752,7 +752,7 @@ class CheckpointManager(object):
|
|||||||
"""Returns the `tf.train.Checkpoint` object."""
|
"""Returns the `tf.train.Checkpoint` object."""
|
||||||
return self._checkpoint
|
return self._checkpoint
|
||||||
|
|
||||||
def save(self, checkpoint_number=None, check_interval=True):
|
def save(self, checkpoint_number=None, check_interval=True, options=None):
|
||||||
"""Creates a new checkpoint and manages it.
|
"""Creates a new checkpoint and manages it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -768,6 +768,9 @@ class CheckpointManager(object):
|
|||||||
larger than `checkpoint_interval`. Otherwise it will always save the
|
larger than `checkpoint_interval`. Otherwise it will always save the
|
||||||
checkpoint unless a checkpoint has already been saved for the current
|
checkpoint unless a checkpoint has already been saved for the current
|
||||||
step.
|
step.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object. This argument only
|
||||||
|
works with TF2 checkpoint objects. For example, options =
|
||||||
|
tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The path to the new checkpoint. It is also recorded in the `checkpoints`
|
The path to the new checkpoint. It is also recorded in the `checkpoints`
|
||||||
@ -809,7 +812,10 @@ class CheckpointManager(object):
|
|||||||
checkpoint_number = training_util.global_step(
|
checkpoint_number = training_util.global_step(
|
||||||
sess=session, global_step_tensor=checkpoint_number)
|
sess=session, global_step_tensor=checkpoint_number)
|
||||||
prefix = "%s-%d" % (self._prefix, checkpoint_number)
|
prefix = "%s-%d" % (self._prefix, checkpoint_number)
|
||||||
save_path = self._checkpoint.write(prefix)
|
if options is None:
|
||||||
|
save_path = self._checkpoint.write(prefix)
|
||||||
|
else:
|
||||||
|
save_path = self._checkpoint.write(prefix, options=options)
|
||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
# If this is an overwritten checkpoint we were previously tracking, delete
|
# If this is an overwritten checkpoint we were previously tracking, delete
|
||||||
# and reinsert it to make sure it goes to the end of the queue.
|
# and reinsert it to make sure it goes to the end of the queue.
|
||||||
|
@ -32,6 +32,6 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "save"
|
name: "save"
|
||||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,6 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "save"
|
name: "save"
|
||||||
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user