diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index f99340e6bad..2e5db7edd27 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -310,6 +310,7 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", + "//tensorflow/python/training/saving:checkpoint_options", "//tensorflow/python/training/saving:functional_saver", "//tensorflow/python/training/tracking", "//tensorflow/python/training/tracking:base", diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index a5d6353280c..9553fb5b196 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -52,6 +52,7 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import signature_serialization from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils_impl +from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.saving import functional_saver from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import graph_view @@ -941,6 +942,7 @@ def save(obj, export_dir, signatures=None, options=None): May not be called from within a function body. @end_compatibility """ + options = options or save_options.SaveOptions() # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x # compatible (no sessions) and share it with this export API rather than # making a SavedModel proto and writing it directly. @@ -954,7 +956,10 @@ def save(obj, export_dir, signatures=None, options=None): # Write the checkpoint, copy assets into the assets directory, and write out # the SavedModel proto itself. utils_impl.get_or_create_variables_dir(export_dir) - object_saver.save(utils_impl.get_variables_path(export_dir)) + ckpt_options = checkpoint_options.CheckpointOptions( + experimental_io_device=options.experimental_io_device) + object_saver.save(utils_impl.get_variables_path(export_dir), + options=ckpt_options) builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir) # Note that this needs to be the last file operation when saving the @@ -976,6 +981,7 @@ def save(obj, export_dir, signatures=None, options=None): def export_meta_graph(obj, filename, signatures=None, options=None): """Exports the MetaGraph proto to a file.""" + options = options or save_options.SaveOptions() export_dir = os.path.dirname(filename) meta_graph_def, exported_graph, _, _ = _build_meta_graph( obj, export_dir, signatures, options) @@ -1001,7 +1007,6 @@ def _build_meta_graph(obj, export_dir, signatures, options, if not isinstance(obj, base.Trackable): raise ValueError( "Expected a Trackable object for export, got {}.".format(obj)) - options = options or save_options.SaveOptions() meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef() checkpoint_graph_view = _AugmentedGraphView(obj) diff --git a/tensorflow/python/saved_model/save_options.py b/tensorflow/python/saved_model/save_options.py index a8528c002e3..748ae7600eb 100644 --- a/tensorflow/python/saved_model/save_options.py +++ b/tensorflow/python/saved_model/save_options.py @@ -33,12 +33,14 @@ class SaveOptions(object): """ # Define object attributes in __slots__ for improved memory and performance. 
- __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases") + __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases", + "experimental_io_device") def __init__(self, namespace_whitelist=None, save_debug_info=False, - function_aliases=None): + function_aliases=None, + experimental_io_device=None): """Creates an object that stores options for SavedModel saving. Args: @@ -46,16 +48,15 @@ class SaveOptions(object): when saving a model. Saving an object that uses namespaced ops must explicitly add all namespaces to the whitelist. The namespaced ops must be registered into the framework when loading the SavedModel. - save_debug_info: Boolean indicating whether debug information is saved. - If True, then a debug/saved_model_debug_info.pb file will be written - with the contents of a GraphDebugInfo binary protocol buffer containing - stack trace information for all ops and functions that are saved. + save_debug_info: Boolean indicating whether debug information is saved. If + True, then a debug/saved_model_debug_info.pb file will be written with + the contents of a GraphDebugInfo binary protocol buffer containing stack + trace information for all ops and functions that are saved. function_aliases: Python dict. Mapping from string to object returned by - @tf.function. - A single tf.function can generate many ConcreteFunctions. If a - downstream tool wants to refer to all concrete functions generated by a - single tf.function you can use the `function_aliases` argument to store - a map from the alias name to all concrete function names. + @tf.function. A single tf.function can generate many ConcreteFunctions. + If a downstream tool wants to refer to all concrete functions generated + by a single tf.function you can use the `function_aliases` argument to + store a map from the alias name to all concrete function names. E.g. ```python class MyModel: @@ -77,11 +78,21 @@ class SaveOptions(object): }) tf.saved_model.save(model, export_dir, signatures, options) ``` + experimental_io_device: string. Applies in a distributed setting. + Tensorflow device to use to access the filesystem. If `None` (default) + then for each variable the filesystem is accessed from the CPU:0 device + of the host where that variable is assigned. If specified, the + filesystem is instead accessed from that device for all variables. + + This is for example useful if you want to save to a local directory, + such as "/tmp" when running in a distributed setting. In that case pass + a device for the host where the "/tmp" directory is accessible. 
""" self.namespace_whitelist = _validate_namespace_whitelist( namespace_whitelist) self.save_debug_info = save_debug_info self.function_aliases = function_aliases if function_aliases else dict() + self.experimental_io_device = experimental_io_device def _validate_namespace_whitelist(namespace_whitelist): diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index cae8c4c7c96..09e7296a483 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -577,6 +577,12 @@ class SavingOptionsTest(test.TestCase): self.assertEqual(function_cache[0].name.decode("utf-8"), list(function_aliases.keys())[0]) + def test_accepts_io_device(self): + options = save_options.SaveOptions() + self.assertEqual(None, options.experimental_io_device) + options = save_options.SaveOptions(experimental_io_device="/job:localhost") + self.assertEqual("/job:localhost", options.experimental_io_device) + class AssetTests(test.TestCase): diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD index a8f595f3ac6..670a4c35c6f 100644 --- a/tensorflow/python/training/saving/BUILD +++ b/tensorflow/python/training/saving/BUILD @@ -12,11 +12,20 @@ package( exports_files(["LICENSE"]) +py_library( + name = "checkpoint_options", + srcs = ["checkpoint_options.py"], + deps = [ + "//tensorflow/python:tf_export", + ], +) + py_library( name = "functional_saver", srcs = ["functional_saver.py"], srcs_version = "PY2AND3", deps = [ + ":checkpoint_options", ":saveable_hook", ":saveable_object", ":saveable_object_util", @@ -31,6 +40,7 @@ cuda_py_test( "functional_saver_test.py", ], deps = [ + ":checkpoint_options", ":functional_saver", ":saveable_hook", "//tensorflow/python/eager:test", diff --git a/tensorflow/python/training/saving/checkpoint_options.py b/tensorflow/python/training/saving/checkpoint_options.py new file mode 100644 index 00000000000..92fd679943c --- /dev/null +++ b/tensorflow/python/training/saving/checkpoint_options.py @@ -0,0 +1,58 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Options for saving Checkpoints.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.CheckpointOptions") +class CheckpointOptions(object): + """Options for constructing a Checkpoint. + + Used as the `_options` argument to the `tf.Checkpoint` constructor to adjust + how variables are saved. 
+ + Example: Run IO ops on "localhost" while saving a checkpoint: + + ``` + step = tf.Variable(0, name="step") + checkpoint = tf.train.Checkpoint(step=step) + options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") + checkpoint.save("/tmp/ckpt", options=options) + ``` + """ + + # Define object attributes in __slots__ for improved memory and performance. + __slots__ = ("experimental_io_device",) + + def __init__(self, experimental_io_device=None): + """Creates an object that stores options for a Checkpoint. + + Args: + experimental_io_device: string. Applies in a distributed setting. + TensorFlow device to use to access the filesystem. If `None` (default) + then for each variable the filesystem is accessed from the CPU:0 device + of the host where that variable is assigned. If specified, the + filesystem is instead accessed from that device for all variables. + + This is useful, for example, when saving to a local directory such as + "/tmp" while running in a distributed setting. In that case, pass a + device for the host where the "/tmp" directory is accessible. + """ + self.experimental_io_device = experimental_io_device diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py index d85852dabe6..c4334e096df 100644 --- a/tensorflow/python/training/saving/functional_saver.py +++ b/tensorflow/python/training/saving/functional_saver.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import string_ops +from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.saving import saveable_hook from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util @@ -52,15 +53,17 @@ class _SingleDeviceSaver(object): "Expected a list of SaveableObjects, got %s." % (saveable,)) self._saveable_objects = saveable_objects - def save(self, file_prefix): + def save(self, file_prefix, options=None): """Save the saveable objects to a checkpoint with `file_prefix`. Args: file_prefix: A string or scalar string Tensor containing the prefix to save under. + options: Optional `CheckpointOptions` object. Returns: An `Operation`, or None when executing eagerly. """ + options = options or checkpoint_options.CheckpointOptions() tensor_names = [] tensors = [] tensor_slices = [] @@ -69,19 +72,22 @@ class _SingleDeviceSaver(object): tensor_names.append(spec.name) tensors.append(spec.tensor) tensor_slices.append(spec.slice_spec) - with ops.device("cpu:0"): + save_device = options.experimental_io_device or "cpu:0" + with ops.device(save_device): return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors) - def restore(self, file_prefix): + def restore(self, file_prefix, options=None): """Restore the saveable objects from a checkpoint with `file_prefix`. Args: file_prefix: A string or scalar string Tensor containing the prefix for files to read from. + options: Optional `CheckpointOptions` object. Returns: A dictionary mapping from SaveableObject names to restore operations.
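The essential mechanic `_SingleDeviceSaver.save` gains above is a one-line device override around the `SaveV2` op. Stripped of the saver machinery, the pattern reduces to the following hedged sketch, where `io_device` stands in for `options.experimental_io_device` and `tf.io.write_file` is just a stand-in filesystem-touching op:

```python
import tensorflow as tf

def place_io_op(io_device=None):
  # Mirrors `options.experimental_io_device or "cpu:0"`: fall back to the
  # host CPU when no override is given, otherwise pin filesystem access to
  # the requested device.
  with tf.device(io_device or "cpu:0"):
    return tf.io.write_file("/tmp/io_probe.txt", "probe")

place_io_op()                  # default placement: local cpu:0
place_io_op("/job:localhost")  # placement forced onto /job:localhost
```

`restore` applies the identical rule around `RestoreV2`, so save and restore touch the filesystem from the same device whenever an override is set.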
""" + options = options or checkpoint_options.CheckpointOptions() restore_specs = [] tensor_structure = [] for saveable in self._saveable_objects: @@ -91,7 +97,8 @@ class _SingleDeviceSaver(object): saveable_tensor_structure.append(spec.name) restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) - with ops.device("cpu:0"): + restore_device = options.experimental_io_device or "cpu:0" + with ops.device(restore_device): restored_tensors = io_ops.restore_v2( file_prefix, tensor_names, tensor_slices, tensor_dtypes) structured_restored_tensors = nest.pack_sequence_as( @@ -190,15 +197,17 @@ class MultiDeviceSaver(object): with ops.control_dependencies(restore_ops.values()): return array_ops.identity(file_prefix) - def save(self, file_prefix): + def save(self, file_prefix, options=None): """Save the saveable objects to a checkpoint with `file_prefix`. Args: file_prefix: A string or scalar string Tensor containing the prefix to save under. + options: Optional `CheckpointOptions` object. Returns: An `Operation`, or None when executing eagerly. """ + options = options or checkpoint_options.CheckpointOptions() for callback in self._before_save_callbacks: callback() @@ -253,32 +262,37 @@ class MultiDeviceSaver(object): with ops.device(device): # _SingleDeviceSaver will use the CPU device when necessary, but initial # read operations should be placed on the SaveableObject's device. - sharded_saves.append(saver.save(shard_prefix)) + sharded_saves.append(saver.save(shard_prefix, options)) with ops.control_dependencies(sharded_saves): - # Co-locates the merge step with the last device. - with ops.device(saveable_object_util.set_cpu0(last_device)): + # Merge on the io_device if specified, otherwise co-locates the merge op + # with the last device used. + merge_device = (options.experimental_io_device or + saveable_object_util.set_cpu0(last_device)) + with ops.device(merge_device): # V2 format write path consists of a metadata merge step. Once merged, # attempts to delete the temporary directory, "_temp". return gen_io_ops.merge_v2_checkpoints( sharded_prefixes, file_prefix, delete_old_dirs=True) - def restore(self, file_prefix): + def restore(self, file_prefix, options=None): """Restore the saveable objects from a checkpoint with `file_prefix`. Args: file_prefix: A string or scalar string Tensor containing the prefix for files to read from. + options: Optional `CheckpointOptions` object. Returns: A dictionary mapping from SaveableObject names to restore operations. """ + options = options or checkpoint_options.CheckpointOptions() restore_ops = {} # Sort by device name to avoid propagating non-deterministic dictionary # ordering in some Python versions. 
for device, saver in sorted(self._single_device_savers.items()): with ops.device(device): - restore_ops.update(saver.restore(file_prefix)) + restore_ops.update(saver.restore(file_prefix, options)) for callback in self._after_restore_callbacks: callback() diff --git a/tensorflow/python/training/saving/functional_saver_test.py b/tensorflow/python/training/saving/functional_saver_test.py index dfa2023a091..7db32ff72d7 100644 --- a/tensorflow/python/training/saving/functional_saver_test.py +++ b/tensorflow/python/training/saving/functional_saver_test.py @@ -20,21 +20,37 @@ from __future__ import print_function import os -from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.eager import wrap_function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import gfile +from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.saving import functional_saver from tensorflow.python.training.saving import saveable_hook from tensorflow.python.training.saving import saveable_object_util +LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0" + class SaverTest(test.TestCase): + def setUp(self): + super(SaverTest, self).setUp() + cpus = config.list_physical_devices("CPU") + # Set 3 virtual CPUs + config.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration() + ]) + self.local_options = checkpoint_options.CheckpointOptions( + experimental_io_device=LOCALHOST) + @test_util.run_in_graph_and_eager_modes def test_resource_variable(self): v1 = resource_variable_ops.ResourceVariable(2.) @@ -55,6 +71,33 @@ class SaverTest(test.TestCase): self.evaluate(second_saver.restore(prefix)) self.assertEqual(2., self.evaluate(v2)) + @test_util.run_in_graph_and_eager_modes + def test_resource_variable_use_localhost(self): + v1 = resource_variable_ops.ResourceVariable(2.) + self.evaluate(v1.initializer) + saver = functional_saver._SingleDeviceSaver( + saveable_object_util.saveable_objects_for_op(v1, "x")) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(saver.save(constant_op.constant(prefix), self.local_options)) + self.assertEqual(2, len(gfile.Glob(prefix + "*"))) + self.evaluate(v1.assign(1.)) + self.evaluate(saver.restore(prefix, self.local_options)) + self.assertEqual(2., self.evaluate(v1)) + + v2 = resource_variable_ops.ResourceVariable(3.) + self.evaluate(v2.initializer) + second_saver = functional_saver._SingleDeviceSaver( + saveable_object_util.saveable_objects_for_op(v2, "x")) + self.evaluate(second_saver.restore(prefix, self.local_options)) + self.assertEqual(2., self.evaluate(v2)) + + # In graph mode, verify that the save and restore ops were set to run on + # localhost. + if not context.executing_eagerly(): + for op in ops.get_default_graph().get_operations(): + if op.type in ("SaveV2", "RestoreV2"): + self.assertEqual(LOCALHOST, op.device) + def test_to_proto(self): v1 = resource_variable_ops.ResourceVariable(2.) 
saver = functional_saver.MultiDeviceSaver( @@ -83,12 +126,7 @@ class SaverTest(test.TestCase): second_saver.restore(save_path) self.assertEqual(2., self.evaluate(v2)) - @test_util.run_v1_only( - "Needs an API to setup multiple devices, b/124805129") - # Set up multiple devices when graph building. Before test.main() we configure - # the devices for eager execution. - @test_util.run_in_graph_and_eager_modes( - config=config_pb2.ConfigProto(device_count={"CPU": 3})) + @test_util.run_in_graph_and_eager_modes def test_checkpoint_is_sharded_by_device(self): with ops.device("cpu:0"): v0 = resource_variable_ops.ResourceVariable(0.) @@ -99,9 +137,9 @@ class SaverTest(test.TestCase): self.evaluate([v0.initializer, v1.initializer, v2.initializer]) saver = functional_saver.MultiDeviceSaver( - list(saveable_object_util.saveable_objects_for_op(v0, "v0")) - + list(saveable_object_util.saveable_objects_for_op(v1, "v1")) - + list(saveable_object_util.saveable_objects_for_op(v2, "v2"))) + list(saveable_object_util.saveable_objects_for_op(v0, "v0")) + + list(saveable_object_util.saveable_objects_for_op(v1, "v1")) + + list(saveable_object_util.saveable_objects_for_op(v2, "v2"))) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(saver.save(constant_op.constant(prefix))) self.assertEqual(4, len(gfile.Glob(prefix + "*"))) @@ -113,8 +151,38 @@ class SaverTest(test.TestCase): self.assertEqual(1., self.evaluate(v1)) self.assertEqual(2., self.evaluate(v2)) + @test_util.run_in_graph_and_eager_modes + def test_checkpoint_multi_device_using_localhost(self): + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.) + with ops.device("cpu:1"): + v1 = resource_variable_ops.ResourceVariable(1.) + with ops.device("cpu:2"): + v2 = resource_variable_ops.ResourceVariable(2.) -class SaveableHookTest(test.TestCase): + self.evaluate([v0.initializer, v1.initializer, v2.initializer]) + saver = functional_saver.MultiDeviceSaver( + list(saveable_object_util.saveable_objects_for_op(v0, "v0")) + + list(saveable_object_util.saveable_objects_for_op(v1, "v1")) + + list(saveable_object_util.saveable_objects_for_op(v2, "v2"))) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(saver.save(constant_op.constant(prefix), self.local_options)) + self.assertEqual(4, len(gfile.Glob(prefix + "*"))) + self.evaluate(v0.assign(-1.)) + self.evaluate(v1.assign(-1.)) + self.evaluate(v2.assign(-1.)) + self.evaluate( + saver.restore(constant_op.constant(prefix), self.local_options)) + self.assertEqual(0., self.evaluate(v0)) + self.assertEqual(1., self.evaluate(v1)) + self.assertEqual(2., self.evaluate(v2)) + + # In graph mode, verify that the save and restore ops were set to run on + # localhost. + if not context.executing_eagerly(): + for op in ops.get_default_graph().get_operations(): + if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"): + self.assertEqual(LOCALHOST, op.device) def test_callbacks_run(self): # Use dict because an int would be shadowed inside callback. 
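The new `setUp` above carves the single physical CPU into three logical devices so that `cpu:1` and `cpu:2` exist in both graph and eager mode, replacing the old `ConfigProto(device_count={"CPU": 3})` setup that the deleted decorator relied on. Outside the TF test harness, the equivalent public-API incantation would be roughly the following sketch (it must run before any ops execute):

```python
import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
# Split the first physical CPU into three logical CPU devices, making
# "cpu:0", "cpu:1" and "cpu:2" all valid placements for the sharding tests.
tf.config.set_logical_device_configuration(cpus[0], [
    tf.config.LogicalDeviceConfiguration(),
    tf.config.LogicalDeviceConfiguration(),
    tf.config.LogicalDeviceConfiguration(),
])
assert len(tf.config.list_logical_devices("CPU")) == 3
```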
@@ -144,6 +212,5 @@ class SaveableHookTest(test.TestCase): if __name__ == "__main__": - ops.enable_eager_execution( - config=config_pb2.ConfigProto(device_count={"CPU": 3})) + ops.enable_eager_execution() test.main() diff --git a/tensorflow/python/training/tracking/BUILD b/tensorflow/python/training/tracking/BUILD index 943490218a0..f893e29feab 100644 --- a/tensorflow/python/training/tracking/BUILD +++ b/tensorflow/python/training/tracking/BUILD @@ -150,6 +150,7 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/keras:backend", + "//tensorflow/python/training/saving:checkpoint_options", "//tensorflow/python/training/saving:functional_saver", "//tensorflow/python/training/saving:saveable_object_util", "@six_archive//:six", @@ -191,6 +192,7 @@ tf_py_test( "//tensorflow/python/keras:engine", "//tensorflow/python/keras/layers", "//tensorflow/python/keras/optimizer_v2", + "//tensorflow/python/training/saving:checkpoint_options", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index 24a28e94031..7b603ed5dc2 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training import saver as v1_saver_lib +from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.saving import functional_saver from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base @@ -168,7 +169,7 @@ class _CheckpointRestoreCoordinator(object): """Holds the status of an object-based checkpoint load.""" def __init__(self, object_graph_proto, save_path, save_path_tensor, - restore_op_cache, graph_view): + restore_op_cache, graph_view, options): """Specify the checkpoint being loaded. Args: @@ -184,7 +185,9 @@ class _CheckpointRestoreCoordinator(object): `restore()` calls. graph_view: A graph_view_lib.ObjectGraphView object for the restored objects. + options: A CheckpointOptions object. """ + self.options = options self.object_graph_proto = object_graph_proto self.restore_uid = ops.uid() # Maps from proto ids to lists of attributes which were in the checkpoint @@ -291,7 +294,7 @@ class _CheckpointRestoreCoordinator(object): ("Saveable keys changed when validating. Got back %s, was " "expecting %s") % (tensor_saveables.keys(), validated_names)) new_restore_ops = functional_saver.MultiDeviceSaver( - validated_saveables).restore(self.save_path_tensor) + validated_saveables).restore(self.save_path_tensor, self.options) if not context.executing_eagerly(): for name, restore_op in sorted(new_restore_ops.items()): restore_ops.append(restore_op) @@ -1113,13 +1116,15 @@ class TrackableSaver(object): def _save_cached_when_graph_building(self, file_prefix, - object_graph_tensor=None): + object_graph_tensor, + options): """Create or retrieve save ops. Args: file_prefix: The prefix for saved checkpoint files. object_graph_tensor: A `Tensor` to which the current object graph will be fed. + options: `CheckpointOptions` object. Returns: A two-element tuple with a filename tensor and a feed_dict of tensors to @@ -1137,14 +1142,15 @@ class TrackableSaver(object): # var_list. 
or context.executing_eagerly() or ops.inside_function()): saver = functional_saver.MultiDeviceSaver(named_saveable_objects) - save_op = saver.save(file_prefix) + save_op = saver.save(file_prefix, options=options) with ops.device("/cpu:0"): with ops.control_dependencies([save_op]): self._cached_save_operation = array_ops.identity(file_prefix) self._last_save_object_graph = graph_proto return self._cached_save_operation, feed_additions - def save(self, file_prefix, checkpoint_number=None, session=None): + def save(self, file_prefix, checkpoint_number=None, session=None, + options=None): """Save a training checkpoint. The saved checkpoint includes variables created by this object and any @@ -1162,10 +1168,12 @@ class TrackableSaver(object): session: The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used. + options: Optional `tf.train.CheckpointOptions` object. Returns: The full path to the checkpoint. """ + options = options or checkpoint_options.CheckpointOptions() feed_dict = {} use_session = (not context.executing_eagerly() and not ops.inside_function()) @@ -1189,7 +1197,7 @@ class TrackableSaver(object): file_io.recursive_create_dir(os.path.dirname(file_prefix)) save_path, new_feed_additions = self._save_cached_when_graph_building( - file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor) + file_prefix_tensor, object_graph_tensor, options) if new_feed_additions: feed_dict.update(new_feed_additions) if not use_session: @@ -1202,7 +1210,7 @@ class TrackableSaver(object): else: return save_path - def restore(self, save_path): + def restore(self, save_path, options=None): """Restore a training checkpoint. Restores `root_trackable` and any objects that it tracks @@ -1250,6 +1258,7 @@ class TrackableSaver(object): object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables. + options: Optional `tf.train.CheckpointOptions` object. Returns: A load status object, which can be used to make assertions about the @@ -1260,6 +1269,7 @@ class TrackableSaver(object): If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` object is returned which runs restore ops from a name-based saver. """ + options = options or checkpoint_options.CheckpointOptions() if save_path is None: return InitializationOnlyStatus(self._graph_view, ops.uid()) reader = py_checkpoint_reader.NewCheckpointReader(save_path) @@ -1304,7 +1314,8 @@ class TrackableSaver(object): save_path=save_path, save_path_tensor=file_prefix_tensor, restore_op_cache=self._restore_op_cache, - graph_view=self._graph_view) + graph_view=self._graph_view, + options=options) base.CheckpointPosition( checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root) load_status = CheckpointLoadStatus( @@ -1736,6 +1747,8 @@ class Checkpoint(tracking.AutoTrackable): checkpoint_directory = "/tmp/training_checkpoints" checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + # Create a Checkpoint that will manage two objects with trackable state, + # one we name "optimizer" and the other we name "model". 
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) for _ in range(num_training_steps): @@ -1744,7 +1757,7 @@ checkpoint.save(file_prefix=checkpoint_prefix) ``` - `Checkpoint.save` and `Checkpoint.restore` write and read object-based + `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which writes and reads `variable.name` based checkpoints. Object-based checkpointing saves a @@ -1757,7 +1770,7 @@ arguments to their constructors, and each dependency is given a name that is identical to the name of the keyword argument for which it was created. TensorFlow classes like `Layer`s and `Optimizer`s will automatically add - dependencies on their variables (e.g. "kernel" and "bias" for + dependencies on their own variables (e.g. "kernel" and "bias" for `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing dependencies easy in user-defined classes, since `Model` hooks into attribute assignment. For example: @@ -1840,7 +1853,7 @@ dtype=dtypes.int64, trainable=False)) - def write(self, file_prefix): + def write(self, file_prefix, options=None): """Writes a training checkpoint. The checkpoint includes variables created by this object and any @@ -1854,14 +1867,35 @@ Checkpoints written with `write` must be read with `read`. + Example usage: + + ``` + step = tf.Variable(0, name="step") + checkpoint = tf.train.Checkpoint(step=step) + checkpoint.write("/tmp/ckpt") + + # Later, read the checkpoint with read() + checkpoint.read("/tmp/ckpt").assert_consumed() + + # You can also pass options to write() and read(). For example, this + # runs the IO ops on localhost: + options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") + checkpoint.write("/tmp/ckpt", options=options) + + # Later, read the checkpoint with read() + checkpoint.read("/tmp/ckpt", options=options).assert_consumed() + ``` + Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). + options: Optional `tf.train.CheckpointOptions` object. Returns: The full path to the checkpoint (i.e. `file_prefix`). """ - output = self._saver.save(file_prefix=file_prefix) + options = options or checkpoint_options.CheckpointOptions() + output = self._saver.save(file_prefix=file_prefix, options=options) if tensor_util.is_tensor(output): if context.executing_eagerly(): return compat.as_str(output.numpy()) @@ -1884,7 +1918,7 @@ self._maybe_create_save_counter() return self._save_counter - def save(self, file_prefix): + def save(self, file_prefix, options=None): """Saves a training checkpoint and provides basic checkpoint management. The saved checkpoint includes variables created by this object and any @@ -1898,14 +1932,33 @@ provided by other utilities which also wrap `write` and `read`. (`tf.train.CheckpointManager` for example). + ``` + step = tf.Variable(0, name="step") + checkpoint = tf.train.Checkpoint(step=step) + checkpoint.save("/tmp/ckpt") + + # Later, restore the checkpoint with restore() + checkpoint.restore("/tmp/ckpt").assert_consumed() + + # You can also pass options to save() and restore(). For example, this
+ # runs the IO ops on localhost: + options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") + checkpoint.save("/tmp/ckpt", options=options) + + # Later, restore the checkpoint with restore() + checkpoint.restore("/tmp/ckpt", options=options).assert_consumed() + ``` + Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and `Checkpoint.save_counter`. + options: Optional `tf.train.CheckpointOptions` object. Returns: The full path to the checkpoint. """ + options = options or checkpoint_options.CheckpointOptions() graph_building = not context.executing_eagerly() if graph_building: if ops.inside_function(): @@ -1931,7 +1984,8 @@ checkpoint_number = session.run(self._save_assign_op) else: checkpoint_number = assign_op.numpy() - file_path = self.write("%s-%d" % (file_prefix, checkpoint_number)) + file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), + options=options) checkpoint_management.update_checkpoint_state_internal( save_dir=os.path.dirname(file_prefix), model_checkpoint_path=file_path, @@ -1939,7 +1993,7 @@ save_relative_paths=True) return file_path - def read(self, save_path): + def read(self, save_path, options=None): """Read a training checkpoint written with `write`. Reads this `Checkpoint` and any objects it depends on. @@ -1962,18 +2016,25 @@ # Later, load the checkpoint with read() # With restore() assert_consumed() would have failed. checkpoint.read(path).assert_consumed() + + # You can also pass options to read(). For example, this + # runs the IO ops on localhost: + options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") + checkpoint.read(path, options=options) ``` Args: save_path: The path to the checkpoint as returned by `write`. + options: Optional `tf.train.CheckpointOptions` object. Returns: A load status object, which can be used to make assertions about the status of a checkpoint restoration. See `restore` for details. """ - return self._saver.restore(save_path=save_path) + options = options or checkpoint_options.CheckpointOptions() + return self._saver.restore(save_path=save_path, options=options) - def restore(self, save_path): + def restore(self, save_path, options=None): """Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on. @@ -1995,6 +2056,10 @@ ```python checkpoint = tf.train.Checkpoint( ... ) checkpoint.restore(path).assert_consumed() + + # You can additionally pass options to restore(): + options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") + checkpoint.restore(path, options=options).assert_consumed() ``` An exception will be raised if any Python objects in the dependency graph @@ -2011,6 +2076,7 @@ `tf.train.latest_checkpoint`. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables. + options: Optional `tf.train.CheckpointOptions` object. Returns: A load status object, which can be used to make assertions about the @@ -2049,7 +2115,7 @@ checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown).
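End to end, the `options` argument threads from `Checkpoint.save` through `write`, `TrackableSaver.save`, and `MultiDeviceSaver.save` down to the `SaveV2` placement shown earlier. A round-trip sketch through the public API added by this patch (paths are illustrative):

```python
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0])
ckpt = tf.train.Checkpoint(step=step, weights=weights)

# The same options object is honored on both the save and restore paths.
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
path = ckpt.save("/tmp/training/ckpt", options=options)  # e.g. ".../ckpt-1"

weights.assign([0.0, 0.0])
ckpt.restore(path, options=options).assert_consumed()
print(weights.numpy())  # [1. 2.], restored through the same io_device
```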
""" - status = self.read(save_path) + status = self.read(save_path, options=options) # Create the save counter now so it gets initialized with other variables # when graph building. Creating it earlier would lead to errors when using, # say, train.Saver() to save the model before initializing it. diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index a69a34c1038..7a96fedc89b 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import graph_view from tensorflow.python.training.tracking import tracking @@ -409,6 +410,28 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): del ckpt status.assert_consumed() + @test_util.run_in_graph_and_eager_modes + def testPassingCheckpointOptions(self): + localhost = "/job:localhost/device:CPU:0" + options = checkpoint_options.CheckpointOptions( + experimental_io_device=localhost) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + v = variable_scope.get_variable(name="v", initializer=0.) + self.evaluate(v.initializer) + ckpt = trackable_utils.Checkpoint(v=v) + self.evaluate(trackable_utils.gather_initializers(ckpt)) + save_path = ckpt.save(file_prefix=prefix, options=options) + status = ckpt.restore(save_path=save_path, options=options) + del ckpt + status.assert_consumed() + + # In graph mode, verify that the save and restore ops were set to run on + # localhost. 
+ if not context.executing_eagerly(): + for op in ops.get_default_graph().get_operations(): + if op.type in ("SaveV2", "RestoreV2"): + self.assertEqual(localhost, op.device) + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt index 98462326401..6a8163c1335 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt @@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions" tf_class { is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>" is_instance: "<type \'object\'>" + member { + name: "experimental_io_device" + mtype: "<type \'property\'>" + } member { name: "function_aliases" mtype: "<type \'property\'>" } @@ -16,6 +20,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt new file mode 100644 index 00000000000..b86e4cbb762 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.CheckpointOptions" +tf_class { + is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>" + is_instance: "<type \'object\'>" + member { + name: "experimental_io_device" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt index c71bc4af3ec..f89c502a73b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt @@ -28,6 +28,10 @@ tf_module { name: "CheckpointManager" mtype: "<type \'type\'>" } + member { + name: "CheckpointOptions" + mtype: "<type \'type\'>" + } member { name: "CheckpointSaverHook" mtype: "<type \'type\'>" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt index 98462326401..6a8163c1335 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt @@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions" tf_class { is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>" is_instance: "<type \'object\'>" + member { + name: "experimental_io_device" + mtype: "<type \'property\'>" + } member { name: "function_aliases" mtype: "<type \'property\'>" } @@ -16,6 +20,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt new file mode
100644 index 00000000000..b86e4cbb762 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.CheckpointOptions" +tf_class { + is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>" + is_instance: "<type \'object\'>" + member { + name: "experimental_io_device" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt index d7e93a0f937..56651271c13 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt @@ -14,18 +14,18 @@ tf_class { } member_method { name: "read" - argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "restore" - argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "save" - argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "write" - argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index 13dc9829d66..f354e5d6e1c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "CheckpointManager" mtype: "<type \'type\'>" } + member { + name: "CheckpointOptions" + mtype: "<type \'type\'>" + } member { name: "ClusterDef" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" }
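As a recap of the `MultiDeviceSaver` merge-step placement changed earlier in this patch: with an `experimental_io_device` set, the `MergeV2Checkpoints` metadata merge runs there; otherwise it stays co-located with the CPU of the last shard's device. A standalone sketch of that selection rule, where `set_cpu0` is a stand-in for `saveable_object_util.set_cpu0` rather than a public API:

```python
def merge_device_for(io_device, last_shard_device):
  """Illustrative mirror of MultiDeviceSaver.save's merge placement rule."""
  def set_cpu0(device):
    # Stand-in: replace the trailing device component with CPU:0.
    return device.rsplit("/device:", 1)[0] + "/device:CPU:0"
  return io_device or set_cpu0(last_shard_device)

assert merge_device_for(None, "/job:worker/task:1/device:GPU:0") == (
    "/job:worker/task:1/device:CPU:0")
assert merge_device_for("/job:localhost", "/job:worker/task:1/device:GPU:0") == (
    "/job:localhost")
```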