Add an option to choose the I/O device for saving and loading models.

This option enables saving and restoring models to or from filesystems
that are only accessible from localhost when using multiple devices.

The option is available for (see the sketch below):
 - Saving models: tf.saved_model.save(), via tf.saved_model.SaveOptions
 - Checkpoints: tf.train.Checkpoint, via tf.train.CheckpointOptions
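
A minimal usage sketch; the "/job:localhost" device string and the "/tmp"
path are illustrative placeholders:

```python
import tensorflow as tf

step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
# Run the checkpoint IO ops on localhost, so "/tmp" refers to this host's
# filesystem even when the variables live on other devices.
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)
```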

PiperOrigin-RevId: 307858098
Change-Id: I4cd0a81424e306f0eac40bfb30d5067dfc02d1be
A. Unique TensorFlower, 2020-04-22 11:23:53 -07:00 (committed by TensorFlower Gardener)
parent 0493a020d4
commit e853835634
18 changed files with 365 additions and 60 deletions


@@ -310,6 +310,7 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/saving:checkpoint_options",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",


@@ -52,6 +52,7 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
@@ -941,6 +942,7 @@ def save(obj, export_dir, signatures=None, options=None):
May not be called from within a function body.
@end_compatibility
"""
options = options or save_options.SaveOptions()
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
# making a SavedModel proto and writing it directly.
@@ -954,7 +956,10 @@ def save(obj, export_dir, signatures=None, options=None):
# Write the checkpoint, copy assets into the assets directory, and write out
# the SavedModel proto itself.
utils_impl.get_or_create_variables_dir(export_dir)
object_saver.save(utils_impl.get_variables_path(export_dir))
ckpt_options = checkpoint_options.CheckpointOptions(
experimental_io_device=options.experimental_io_device)
object_saver.save(utils_impl.get_variables_path(export_dir),
options=ckpt_options)
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
export_dir)
# Note that this needs to be the last file operation when saving the
@@ -976,6 +981,7 @@ def save(obj, export_dir, signatures=None, options=None):
def export_meta_graph(obj, filename, signatures=None, options=None):
"""Exports the MetaGraph proto to a file."""
options = options or save_options.SaveOptions()
export_dir = os.path.dirname(filename)
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
obj, export_dir, signatures, options)
@@ -1001,7 +1007,6 @@ def _build_meta_graph(obj, export_dir, signatures, options,
if not isinstance(obj, base.Trackable):
raise ValueError(
"Expected a Trackable object for export, got {}.".format(obj))
options = options or save_options.SaveOptions()
meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
checkpoint_graph_view = _AugmentedGraphView(obj)


@@ -33,12 +33,14 @@ class SaveOptions(object):
"""
# Define object attributes in __slots__ for improved memory and performance.
__slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases")
__slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
"experimental_io_device")
def __init__(self,
namespace_whitelist=None,
save_debug_info=False,
function_aliases=None):
function_aliases=None,
experimental_io_device=None):
"""Creates an object that stores options for SavedModel saving.
Args:
@@ -46,16 +48,15 @@ class SaveOptions(object):
when saving a model. Saving an object that uses namespaced ops must
explicitly add all namespaces to the whitelist. The namespaced ops must
be registered into the framework when loading the SavedModel.
save_debug_info: Boolean indicating whether debug information is saved.
If True, then a debug/saved_model_debug_info.pb file will be written
with the contents of a GraphDebugInfo binary protocol buffer containing
stack trace information for all ops and functions that are saved.
save_debug_info: Boolean indicating whether debug information is saved. If
True, then a debug/saved_model_debug_info.pb file will be written with
the contents of a GraphDebugInfo binary protocol buffer containing stack
trace information for all ops and functions that are saved.
function_aliases: Python dict. Mapping from string to object returned by
@tf.function.
A single tf.function can generate many ConcreteFunctions. If a
downstream tool wants to refer to all concrete functions generated by a
single tf.function you can use the `function_aliases` argument to store
a map from the alias name to all concrete function names.
@tf.function. A single tf.function can generate many ConcreteFunctions.
If a downstream tool wants to refer to all concrete functions generated
by a single tf.function you can use the `function_aliases` argument to
store a map from the alias name to all concrete function names.
E.g.
```python
class MyModel:
@@ -77,11 +78,21 @@ class SaveOptions(object):
})
tf.saved_model.save(model, export_dir, signatures, options)
```
experimental_io_device: string. Applies in a distributed setting.
TensorFlow device to use to access the filesystem. If `None` (default)
then for each variable the filesystem is accessed from the CPU:0 device
of the host where that variable is assigned. If specified, the
filesystem is instead accessed from that device for all variables.
This is useful, for example, when you want to save to a local directory
such as "/tmp" while running in a distributed setting. In that case,
pass a device on the host where the "/tmp" directory is accessible.
"""
self.namespace_whitelist = _validate_namespace_whitelist(
namespace_whitelist)
self.save_debug_info = save_debug_info
self.function_aliases = function_aliases if function_aliases else dict()
self.experimental_io_device = experimental_io_device
def _validate_namespace_whitelist(namespace_whitelist):
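
At the user level, the new field threads through tf.saved_model.save(). A
hedged sketch (the module and export path are placeholders):

```python
import tensorflow as tf

module = tf.Module()
module.v = tf.Variable(1.0)

options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
# The variables checkpoint inside the SavedModel is then written through
# "/job:localhost", via the CheckpointOptions forwarding in save.py above.
tf.saved_model.save(module, "/tmp/exported_model", options=options)
```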


@@ -577,6 +577,12 @@ class SavingOptionsTest(test.TestCase):
self.assertEqual(function_cache[0].name.decode("utf-8"),
list(function_aliases.keys())[0])
def test_accepts_io_device(self):
options = save_options.SaveOptions()
self.assertEqual(None, options.experimental_io_device)
options = save_options.SaveOptions(experimental_io_device="/job:localhost")
self.assertEqual("/job:localhost", options.experimental_io_device)
class AssetTests(test.TestCase):


@@ -12,11 +12,20 @@ package(
exports_files(["LICENSE"])
py_library(
name = "checkpoint_options",
srcs = ["checkpoint_options.py"],
deps = [
"//tensorflow/python:tf_export",
],
)
py_library(
name = "functional_saver",
srcs = ["functional_saver.py"],
srcs_version = "PY2AND3",
deps = [
":checkpoint_options",
":saveable_hook",
":saveable_object",
":saveable_object_util",
@@ -31,6 +40,7 @@ cuda_py_test(
"functional_saver_test.py",
],
deps = [
":checkpoint_options",
":functional_saver",
":saveable_hook",
"//tensorflow/python/eager:test",


@@ -0,0 +1,58 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving Checkpoints."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
@tf_export("train.CheckpointOptions")
class CheckpointOptions(object):
"""Options for constructing a Checkpoint.
Used as the `options` argument to `tf.train.Checkpoint` methods such as
`save` and `restore` to adjust how variables are saved and restored.
Example: Run IO ops on "localhost" while saving a checkpoint:
```
step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)
```
"""
# Define object attributes in __slots__ for improved memory and performance.
__slots__ = ("experimental_io_device",)
def __init__(self, experimental_io_device=None):
"""Creates an object that stores options for a Checkpoint.
Args:
experimental_io_device: string. Applies in a distributed setting.
TensorFlow device to use to access the filesystem. If `None` (default)
then for each variable the filesystem is accessed from the CPU:0 device
of the host where that variable is assigned. If specified, the
filesystem is instead accessed from that device for all variables.
This is useful, for example, when you want to save to a local directory
such as "/tmp" while running in a distributed setting. In that case,
pass a device on the host where the "/tmp" directory is accessible.
"""
self.experimental_io_device = experimental_io_device
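
The same options object is accepted on the read path. A brief sketch using the
lower-level write()/read() pair (paths illustrative):

```python
import tensorflow as tf

step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")

# write()/read() skip the save-counter and checkpoint-management bookkeeping
# that save()/restore() add, but accept the same options object.
checkpoint.write("/tmp/ckpt", options=options)
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
```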


@@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
@@ -52,15 +53,17 @@ class _SingleDeviceSaver(object):
"Expected a list of SaveableObjects, got %s." % (saveable,))
self._saveable_objects = saveable_objects
def save(self, file_prefix):
def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
tensor_names = []
tensors = []
tensor_slices = []
@@ -69,19 +72,22 @@ class _SingleDeviceSaver(object):
tensor_names.append(spec.name)
tensors.append(spec.tensor)
tensor_slices.append(spec.slice_spec)
with ops.device("cpu:0"):
save_device = options.experimental_io_device or "cpu:0"
with ops.device(save_device):
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
def restore(self, file_prefix):
def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
A dictionary mapping from SaveableObject names to restore operations.
"""
options = options or checkpoint_options.CheckpointOptions()
restore_specs = []
tensor_structure = []
for saveable in self._saveable_objects:
@@ -91,7 +97,8 @@ class _SingleDeviceSaver(object):
saveable_tensor_structure.append(spec.name)
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
with ops.device("cpu:0"):
restore_device = options.experimental_io_device or "cpu:0"
with ops.device(restore_device):
restored_tensors = io_ops.restore_v2(
file_prefix, tensor_names, tensor_slices, tensor_dtypes)
structured_restored_tensors = nest.pack_sequence_as(
@@ -190,15 +197,17 @@ class MultiDeviceSaver(object):
with ops.control_dependencies(restore_ops.values()):
return array_ops.identity(file_prefix)
def save(self, file_prefix):
def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
for callback in self._before_save_callbacks:
callback()
@@ -253,32 +262,37 @@ class MultiDeviceSaver(object):
with ops.device(device):
# _SingleDeviceSaver will use the CPU device when necessary, but initial
# read operations should be placed on the SaveableObject's device.
sharded_saves.append(saver.save(shard_prefix))
sharded_saves.append(saver.save(shard_prefix, options))
with ops.control_dependencies(sharded_saves):
# Co-locates the merge step with the last device.
with ops.device(saveable_object_util.set_cpu0(last_device)):
# Merge on the io_device if specified, otherwise co-locate the merge op
# with the last device used.
merge_device = (options.experimental_io_device or
saveable_object_util.set_cpu0(last_device))
with ops.device(merge_device):
# V2 format write path consists of a metadata merge step. Once merged,
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
return gen_io_ops.merge_v2_checkpoints(
sharded_prefixes, file_prefix, delete_old_dirs=True)
def restore(self, file_prefix):
def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
A dictionary mapping from SaveableObject names to restore operations.
"""
options = options or checkpoint_options.CheckpointOptions()
restore_ops = {}
# Sort by device name to avoid propagating non-deterministic dictionary
# ordering in some Python versions.
for device, saver in sorted(self._single_device_savers.items()):
with ops.device(device):
restore_ops.update(saver.restore(file_prefix))
restore_ops.update(saver.restore(file_prefix, options))
for callback in self._after_restore_callbacks:
callback()
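
When graph building, the effect of the option is visible on the emitted ops.
An untested sketch of one way to inspect the placement, mirroring the tests
below (the device string is illustrative):

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session() as sess:
    v = tf.Variable(1.0)
    sess.run(v.initializer)
    checkpoint = tf.train.Checkpoint(v=v)
    options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
    checkpoint.save("/tmp/ckpt", options=options)
    # SaveV2/RestoreV2/MergeV2Checkpoints ops should now be pinned to the
    # requested device rather than each host's cpu:0.
    for op in graph.get_operations():
        if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
            print(op.name, "->", op.device)
```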


@@ -20,21 +20,37 @@ from __future__ import print_function
import os
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object_util
LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"
class SaverTest(test.TestCase):
def setUp(self):
super(SaverTest, self).setUp()
cpus = config.list_physical_devices("CPU")
# Set 3 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
self.local_options = checkpoint_options.CheckpointOptions(
experimental_io_device=LOCALHOST)
@test_util.run_in_graph_and_eager_modes
def test_resource_variable(self):
v1 = resource_variable_ops.ResourceVariable(2.)
@@ -55,6 +71,33 @@ class SaverTest(test.TestCase):
self.evaluate(second_saver.restore(prefix))
self.assertEqual(2., self.evaluate(v2))
@test_util.run_in_graph_and_eager_modes
def test_resource_variable_use_localhost(self):
v1 = resource_variable_ops.ResourceVariable(2.)
self.evaluate(v1.initializer)
saver = functional_saver._SingleDeviceSaver(
saveable_object_util.saveable_objects_for_op(v1, "x"))
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
self.assertEqual(2, len(gfile.Glob(prefix + "*")))
self.evaluate(v1.assign(1.))
self.evaluate(saver.restore(prefix, self.local_options))
self.assertEqual(2., self.evaluate(v1))
v2 = resource_variable_ops.ResourceVariable(3.)
self.evaluate(v2.initializer)
second_saver = functional_saver._SingleDeviceSaver(
saveable_object_util.saveable_objects_for_op(v2, "x"))
self.evaluate(second_saver.restore(prefix, self.local_options))
self.assertEqual(2., self.evaluate(v2))
# In graph mode, verify that the save and restore ops were set to run on
# localhost.
if not context.executing_eagerly():
for op in ops.get_default_graph().get_operations():
if op.type in ("SaveV2", "RestoreV2"):
self.assertEqual(LOCALHOST, op.device)
def test_to_proto(self):
v1 = resource_variable_ops.ResourceVariable(2.)
saver = functional_saver.MultiDeviceSaver(
@@ -83,12 +126,7 @@ class SaverTest(test.TestCase):
second_saver.restore(save_path)
self.assertEqual(2., self.evaluate(v2))
@test_util.run_v1_only(
"Needs an API to setup multiple devices, b/124805129")
# Set up multiple devices when graph building. Before test.main() we configure
# the devices for eager execution.
@test_util.run_in_graph_and_eager_modes(
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
@test_util.run_in_graph_and_eager_modes
def test_checkpoint_is_sharded_by_device(self):
with ops.device("cpu:0"):
v0 = resource_variable_ops.ResourceVariable(0.)
@@ -99,9 +137,9 @@ class SaverTest(test.TestCase):
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
saver = functional_saver.MultiDeviceSaver(
list(saveable_object_util.saveable_objects_for_op(v0, "v0"))
+ list(saveable_object_util.saveable_objects_for_op(v1, "v1"))
+ list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(saver.save(constant_op.constant(prefix)))
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
@@ -113,8 +151,38 @@ class SaverTest(test.TestCase):
self.assertEqual(1., self.evaluate(v1))
self.assertEqual(2., self.evaluate(v2))
@test_util.run_in_graph_and_eager_modes
def test_checkpoint_multi_device_using_localhost(self):
with ops.device("cpu:0"):
v0 = resource_variable_ops.ResourceVariable(0.)
with ops.device("cpu:1"):
v1 = resource_variable_ops.ResourceVariable(1.)
with ops.device("cpu:2"):
v2 = resource_variable_ops.ResourceVariable(2.)
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
saver = functional_saver.MultiDeviceSaver(
list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
self.evaluate(v0.assign(-1.))
self.evaluate(v1.assign(-1.))
self.evaluate(v2.assign(-1.))
self.evaluate(
saver.restore(constant_op.constant(prefix), self.local_options))
self.assertEqual(0., self.evaluate(v0))
self.assertEqual(1., self.evaluate(v1))
self.assertEqual(2., self.evaluate(v2))
# In graph mode, verify that the save and restore ops were set to run on
# localhost.
if not context.executing_eagerly():
for op in ops.get_default_graph().get_operations():
if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
self.assertEqual(LOCALHOST, op.device)
class SaveableHookTest(test.TestCase):

def test_callbacks_run(self):
# Use dict because an int would be shadowed inside callback.
@@ -144,6 +212,5 @@ class SaveableHookTest(test.TestCase):
if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
ops.enable_eager_execution()
test.main()


@@ -150,6 +150,7 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/keras:backend",
"//tensorflow/python/training/saving:checkpoint_options",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/saving:saveable_object_util",
"@six_archive//:six",
@@ -191,6 +192,7 @@ tf_py_test(
"//tensorflow/python/keras:engine",
"//tensorflow/python/keras/layers",
"//tensorflow/python/keras/optimizer_v2",
"//tensorflow/python/training/saving:checkpoint_options",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],


@@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
@@ -168,7 +169,7 @@ class _CheckpointRestoreCoordinator(object):
"""Holds the status of an object-based checkpoint load."""
def __init__(self, object_graph_proto, save_path, save_path_tensor,
restore_op_cache, graph_view):
restore_op_cache, graph_view, options):
"""Specify the checkpoint being loaded.
Args:
@@ -184,7 +185,9 @@ class _CheckpointRestoreCoordinator(object):
`restore()` calls.
graph_view: A graph_view_lib.ObjectGraphView object for the restored
objects.
options: A CheckpointOptions object.
"""
self.options = options
self.object_graph_proto = object_graph_proto
self.restore_uid = ops.uid()
# Maps from proto ids to lists of attributes which were in the checkpoint
@@ -291,7 +294,7 @@ class _CheckpointRestoreCoordinator(object):
("Saveable keys changed when validating. Got back %s, was "
"expecting %s") % (tensor_saveables.keys(), validated_names))
new_restore_ops = functional_saver.MultiDeviceSaver(
validated_saveables).restore(self.save_path_tensor)
validated_saveables).restore(self.save_path_tensor, self.options)
if not context.executing_eagerly():
for name, restore_op in sorted(new_restore_ops.items()):
restore_ops.append(restore_op)
@@ -1113,13 +1116,15 @@ class TrackableSaver(object):
def _save_cached_when_graph_building(self,
file_prefix,
object_graph_tensor=None):
object_graph_tensor,
options):
"""Create or retrieve save ops.
Args:
file_prefix: The prefix for saved checkpoint files.
object_graph_tensor: A `Tensor` to which the current object graph will be
fed.
options: `CheckpointOptions` object.
Returns:
A two-element tuple with a filename tensor and a feed_dict of tensors to
@@ -1137,14 +1142,15 @@ class TrackableSaver(object):
# var_list.
or context.executing_eagerly() or ops.inside_function()):
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
save_op = saver.save(file_prefix)
save_op = saver.save(file_prefix, options=options)
with ops.device("/cpu:0"):
with ops.control_dependencies([save_op]):
self._cached_save_operation = array_ops.identity(file_prefix)
self._last_save_object_graph = graph_proto
return self._cached_save_operation, feed_additions
def save(self, file_prefix, checkpoint_number=None, session=None):
def save(self, file_prefix, checkpoint_number=None, session=None,
options=None):
"""Save a training checkpoint.
The saved checkpoint includes variables created by this object and any
@@ -1162,10 +1168,12 @@ class TrackableSaver(object):
session: The session to evaluate variables in. Ignored when executing
eagerly. If not provided when graph building, the default session is
used.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
The full path to the checkpoint.
"""
options = options or checkpoint_options.CheckpointOptions()
feed_dict = {}
use_session = (not context.executing_eagerly() and
not ops.inside_function())
@@ -1189,7 +1197,7 @@ class TrackableSaver(object):
file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
file_prefix_tensor, object_graph_tensor, options)
if new_feed_additions:
feed_dict.update(new_feed_additions)
if not use_session:
@@ -1202,7 +1210,7 @@ class TrackableSaver(object):
else:
return save_path
def restore(self, save_path):
def restore(self, save_path, options=None):
"""Restore a training checkpoint.
Restores `root_trackable` and any objects that it tracks
@@ -1250,6 +1258,7 @@ class TrackableSaver(object):
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
A load status object, which can be used to make assertions about the
@@ -1260,6 +1269,7 @@ class TrackableSaver(object):
If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
object is returned which runs restore ops from a name-based saver.
"""
options = options or checkpoint_options.CheckpointOptions()
if save_path is None:
return InitializationOnlyStatus(self._graph_view, ops.uid())
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
@@ -1304,7 +1314,8 @@ class TrackableSaver(object):
save_path=save_path,
save_path_tensor=file_prefix_tensor,
restore_op_cache=self._restore_op_cache,
graph_view=self._graph_view)
graph_view=self._graph_view,
options=options)
base.CheckpointPosition(
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
load_status = CheckpointLoadStatus(
@@ -1736,6 +1747,8 @@ class Checkpoint(tracking.AutoTrackable):
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
@@ -1744,7 +1757,7 @@ class Checkpoint(tracking.AutoTrackable):
checkpoint.save(file_prefix=checkpoint_prefix)
```
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
`Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
writes and
reads `variable.name` based checkpoints. Object-based checkpointing saves a
@@ -1757,7 +1770,7 @@ class Checkpoint(tracking.AutoTrackable):
arguments to their constructors, and each dependency is given a name that is
identical to the name of the keyword argument for which it was created.
TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
dependencies on their variables (e.g. "kernel" and "bias" for
dependencies on their own variables (e.g. "kernel" and "bias" for
`tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
dependencies easy in user-defined classes, since `Model` hooks into attribute
assignment. For example:
@@ -1840,7 +1853,7 @@ class Checkpoint(tracking.AutoTrackable):
dtype=dtypes.int64,
trainable=False))
def write(self, file_prefix):
def write(self, file_prefix, options=None):
"""Writes a training checkpoint.
The checkpoint includes variables created by this object and any
@@ -1854,14 +1867,35 @@ class Checkpoint(tracking.AutoTrackable):
Checkpoints written with `write` must be read with `read`.
Example usage:
```
step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
checkpoint.write("/tmp/ckpt")
# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt").assert_consumed()
# You can also pass options to write() and read(). For example this
# runs the IO ops on localhost:
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.write("/tmp/ckpt", options=options)
# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
```
Args:
file_prefix: A prefix to use for the checkpoint filenames
(/path/to/directory/and_a_prefix).
options: Optional `tf.train.CheckpointOptions` object.
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
output = self._saver.save(file_prefix=file_prefix)
options = options or checkpoint_options.CheckpointOptions()
output = self._saver.save(file_prefix=file_prefix, options=options)
if tensor_util.is_tensor(output):
if context.executing_eagerly():
return compat.as_str(output.numpy())
@@ -1884,7 +1918,7 @@ class Checkpoint(tracking.AutoTrackable):
self._maybe_create_save_counter()
return self._save_counter
def save(self, file_prefix):
def save(self, file_prefix, options=None):
"""Saves a training checkpoint and provides basic checkpoint management.
The saved checkpoint includes variables created by this object and any
@@ -1898,14 +1932,33 @@ class Checkpoint(tracking.AutoTrackable):
provided by other utilities which also wrap `write` and `read`.
(`tf.train.CheckpointManager` for example).
```
step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
checkpoint.save("/tmp/ckpt")
# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt").assert_consumed()
# You can also pass options to save() and restore(). For example this
# runs the IO ops on localhost:
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)
# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
```
Args:
file_prefix: A prefix to use for the checkpoint filenames
(/path/to/directory/and_a_prefix). Names are generated based on this
prefix and `Checkpoint.save_counter`.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
The full path to the checkpoint.
"""
options = options or checkpoint_options.CheckpointOptions()
graph_building = not context.executing_eagerly()
if graph_building:
if ops.inside_function():
@@ -1931,7 +1984,8 @@ class Checkpoint(tracking.AutoTrackable):
checkpoint_number = session.run(self._save_assign_op)
else:
checkpoint_number = assign_op.numpy()
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
options=options)
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(file_prefix),
model_checkpoint_path=file_path,
@@ -1939,7 +1993,7 @@ class Checkpoint(tracking.AutoTrackable):
save_relative_paths=True)
return file_path
def read(self, save_path):
def read(self, save_path, options=None):
"""Read a training checkpoint written with `write`.
Reads this `Checkpoint` and any objects it depends on.
@@ -1962,18 +2016,25 @@ class Checkpoint(tracking.AutoTrackable):
# Later, load the checkpoint with read()
# With restore() assert_consumed() would have failed.
checkpoint.read(path).assert_consumed()
# You can also pass options to read(). For example this
# runs the IO ops on localhost:
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.read(path, options=options)
```
Args:
save_path: The path to the checkpoint as returned by `write`.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See `restore` for details.
"""
return self._saver.restore(save_path=save_path)
options = options or checkpoint_options.CheckpointOptions()
return self._saver.restore(save_path=save_path, options=options)
def restore(self, save_path):
def restore(self, save_path, options=None):
"""Restore a training checkpoint.
Restores this `Checkpoint` and any objects it depends on.
@@ -1995,6 +2056,10 @@ class Checkpoint(tracking.AutoTrackable):
```python
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()
# You can additionally pass options to restore():
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()
```
An exception will be raised if any Python objects in the dependency graph
@@ -2011,6 +2076,7 @@ class Checkpoint(tracking.AutoTrackable):
`tf.train.latest_checkpoint`. If the checkpoint was written by the
name-based `tf.compat.v1.train.Saver`, names are used to match
variables.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
A load status object, which can be used to make assertions about the
@@ -2049,7 +2115,7 @@ class Checkpoint(tracking.AutoTrackable):
checkpoint file or object when the `Checkpoint` object is deleted
(often at program shutdown).
"""
status = self.read(save_path)
status = self.read(save_path, options=options)
# Create the save counter now so it gets initialized with other variables
# when graph building. Creating it earlier would lead to errors when using,
# say, train.Saver() to save the model before initializing it.

View File

@@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
@@ -409,6 +410,28 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
del ckpt
status.assert_consumed()
@test_util.run_in_graph_and_eager_modes
def testPassingCheckpointOptions(self):
localhost = "/job:localhost/device:CPU:0"
options = checkpoint_options.CheckpointOptions(
experimental_io_device=localhost)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
v = variable_scope.get_variable(name="v", initializer=0.)
self.evaluate(v.initializer)
ckpt = trackable_utils.Checkpoint(v=v)
self.evaluate(trackable_utils.gather_initializers(ckpt))
save_path = ckpt.save(file_prefix=prefix, options=options)
status = ckpt.restore(save_path=save_path, options=options)
del ckpt
status.assert_consumed()
# In graph mode, verify that the save and restore ops were set to run on
# localhost.
if not context.executing_eagerly():
for op in ops.get_default_graph().get_operations():
if op.type in ("SaveV2", "RestoreV2"):
self.assertEqual(localhost, op.device)
@test_util.run_in_graph_and_eager_modes
def testSaveRestore(self):
model = MyModel()


@@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
tf_class {
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_io_device"
mtype: "<type \'member_descriptor\'>"
}
member {
name: "function_aliases"
mtype: "<type \'member_descriptor\'>"
@@ -16,6 +20,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
}
}


@@ -0,0 +1,13 @@
path: "tensorflow.train.CheckpointOptions"
tf_class {
is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_io_device"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}


@@ -28,6 +28,10 @@ tf_module {
name: "CheckpointManager"
mtype: "<type \'type\'>"
}
member {
name: "CheckpointOptions"
mtype: "<type \'type\'>"
}
member {
name: "CheckpointSaverHook"
mtype: "<type \'type\'>"


@@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
tf_class {
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_io_device"
mtype: "<type \'member_descriptor\'>"
}
member {
name: "function_aliases"
mtype: "<type \'member_descriptor\'>"
@@ -16,6 +20,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
}
}


@@ -0,0 +1,13 @@
path: "tensorflow.train.CheckpointOptions"
tf_class {
is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_io_device"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}


@@ -14,18 +14,18 @@ tf_class {
}
member_method {
name: "read"
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "restore"
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "save"
argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "write"
argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}


@@ -12,6 +12,10 @@ tf_module {
name: "CheckpointManager"
mtype: "<type \'type\'>"
}
member {
name: "CheckpointOptions"
mtype: "<type \'type\'>"
}
member {
name: "ClusterDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"