From cd39d7a4c84468d3058dc6911a7119a39c1ee3f3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Dec 2020 15:37:28 -0800 Subject: [PATCH] Add an option to choose the I/O Device for saving and loading models for CheckpointManager. This option enables saving and restoring models to or from filesystems only accessible from the localhost when using multiple devices. PiperOrigin-RevId: 345770754 Change-Id: I8362fbf1402f2e1cb7fa7290a7401cc873c02a1f --- tensorflow/python/training/checkpoint_management.py | 10 ++++++++-- .../v1/tensorflow.train.-checkpoint-manager.pbtxt | 2 +- .../v2/tensorflow.train.-checkpoint-manager.pbtxt | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py index 4387b8ec303..e1e9dee5e84 100644 --- a/tensorflow/python/training/checkpoint_management.py +++ b/tensorflow/python/training/checkpoint_management.py @@ -752,7 +752,7 @@ class CheckpointManager(object): """Returns the `tf.train.Checkpoint` object.""" return self._checkpoint - def save(self, checkpoint_number=None, check_interval=True): + def save(self, checkpoint_number=None, check_interval=True, options=None): """Creates a new checkpoint and manages it. Args: @@ -768,6 +768,9 @@ class CheckpointManager(object): larger than `checkpoint_interval`. Otherwise it will always save the checkpoint unless a checkpoint has already been saved for the current step. + options: Optional `tf.train.CheckpointOptions` object. This argument only + works with TF2 checkpoint objects. For example, options = + tf.train.CheckpointOptions(experimental_io_device='/job:localhost') Returns: The path to the new checkpoint. 
It is also recorded in the `checkpoints` @@ -809,7 +812,10 @@ class CheckpointManager(object): checkpoint_number = training_util.global_step( sess=session, global_step_tensor=checkpoint_number) prefix = "%s-%d" % (self._prefix, checkpoint_number) - save_path = self._checkpoint.write(prefix) + if options is None: + save_path = self._checkpoint.write(prefix) + else: + save_path = self._checkpoint.write(prefix, options=options) timestamp = time.time() # If this is an overwritten checkpoint we were previously tracking, delete # and reinsert it to make sure it goes to the end of the queue. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt index 6ab4e1c085a..985b16f6420 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-manager.pbtxt @@ -32,6 +32,6 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt index 6ab4e1c085a..985b16f6420 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-manager.pbtxt @@ -32,6 +32,6 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + argspec: "args=[\'self\', \'checkpoint_number\', \'check_interval\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', 
\'True\', \'None\'], " } }