Add an option to choose the I/O Device for saving and loading models.
This option enables saving and restoring models to or from filesystems that are only accessible from the localhost when using multiple devices. The option is available for: (1) SavedModels, via tf.saved_model.save(); (2) Checkpoints, via tf.train.Checkpoint. PiperOrigin-RevId: 307858098 Change-Id: I4cd0a81424e306f0eac40bfb30d5067dfc02d1be
This commit is contained in:
parent
0493a020d4
commit
e853835634
|
@ -310,6 +310,7 @@ py_library(
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/eager:function",
|
"//tensorflow/python/eager:function",
|
||||||
|
"//tensorflow/python/training/saving:checkpoint_options",
|
||||||
"//tensorflow/python/training/saving:functional_saver",
|
"//tensorflow/python/training/saving:functional_saver",
|
||||||
"//tensorflow/python/training/tracking",
|
"//tensorflow/python/training/tracking",
|
||||||
"//tensorflow/python/training/tracking:base",
|
"//tensorflow/python/training/tracking:base",
|
||||||
|
|
|
@ -52,6 +52,7 @@ from tensorflow.python.saved_model import signature_def_utils
|
||||||
from tensorflow.python.saved_model import signature_serialization
|
from tensorflow.python.saved_model import signature_serialization
|
||||||
from tensorflow.python.saved_model import tag_constants
|
from tensorflow.python.saved_model import tag_constants
|
||||||
from tensorflow.python.saved_model import utils_impl
|
from tensorflow.python.saved_model import utils_impl
|
||||||
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
|
@ -941,6 +942,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||||
May not be called from within a function body.
|
May not be called from within a function body.
|
||||||
@end_compatibility
|
@end_compatibility
|
||||||
"""
|
"""
|
||||||
|
options = options or save_options.SaveOptions()
|
||||||
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
||||||
# compatible (no sessions) and share it with this export API rather than
|
# compatible (no sessions) and share it with this export API rather than
|
||||||
# making a SavedModel proto and writing it directly.
|
# making a SavedModel proto and writing it directly.
|
||||||
|
@ -954,7 +956,10 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||||
# Write the checkpoint, copy assets into the assets directory, and write out
|
# Write the checkpoint, copy assets into the assets directory, and write out
|
||||||
# the SavedModel proto itself.
|
# the SavedModel proto itself.
|
||||||
utils_impl.get_or_create_variables_dir(export_dir)
|
utils_impl.get_or_create_variables_dir(export_dir)
|
||||||
object_saver.save(utils_impl.get_variables_path(export_dir))
|
ckpt_options = checkpoint_options.CheckpointOptions(
|
||||||
|
experimental_io_device=options.experimental_io_device)
|
||||||
|
object_saver.save(utils_impl.get_variables_path(export_dir),
|
||||||
|
options=ckpt_options)
|
||||||
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
|
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
|
||||||
export_dir)
|
export_dir)
|
||||||
# Note that this needs to be the last file operation when saving the
|
# Note that this needs to be the last file operation when saving the
|
||||||
|
@ -976,6 +981,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||||
|
|
||||||
def export_meta_graph(obj, filename, signatures=None, options=None):
|
def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||||
"""Exports the MetaGraph proto to a file."""
|
"""Exports the MetaGraph proto to a file."""
|
||||||
|
options = options or save_options.SaveOptions()
|
||||||
export_dir = os.path.dirname(filename)
|
export_dir = os.path.dirname(filename)
|
||||||
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
|
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
|
||||||
obj, export_dir, signatures, options)
|
obj, export_dir, signatures, options)
|
||||||
|
@ -1001,7 +1007,6 @@ def _build_meta_graph(obj, export_dir, signatures, options,
|
||||||
if not isinstance(obj, base.Trackable):
|
if not isinstance(obj, base.Trackable):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected a Trackable object for export, got {}.".format(obj))
|
"Expected a Trackable object for export, got {}.".format(obj))
|
||||||
options = options or save_options.SaveOptions()
|
|
||||||
meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
|
meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
|
||||||
|
|
||||||
checkpoint_graph_view = _AugmentedGraphView(obj)
|
checkpoint_graph_view = _AugmentedGraphView(obj)
|
||||||
|
|
|
@ -33,12 +33,14 @@ class SaveOptions(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Define object attributes in __slots__ for improved memory and performance.
|
# Define object attributes in __slots__ for improved memory and performance.
|
||||||
__slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases")
|
__slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
|
||||||
|
"experimental_io_device")
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
namespace_whitelist=None,
|
namespace_whitelist=None,
|
||||||
save_debug_info=False,
|
save_debug_info=False,
|
||||||
function_aliases=None):
|
function_aliases=None,
|
||||||
|
experimental_io_device=None):
|
||||||
"""Creates an object that stores options for SavedModel saving.
|
"""Creates an object that stores options for SavedModel saving.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -46,16 +48,15 @@ class SaveOptions(object):
|
||||||
when saving a model. Saving an object that uses namespaced ops must
|
when saving a model. Saving an object that uses namespaced ops must
|
||||||
explicitly add all namespaces to the whitelist. The namespaced ops must
|
explicitly add all namespaces to the whitelist. The namespaced ops must
|
||||||
be registered into the framework when loading the SavedModel.
|
be registered into the framework when loading the SavedModel.
|
||||||
save_debug_info: Boolean indicating whether debug information is saved.
|
save_debug_info: Boolean indicating whether debug information is saved. If
|
||||||
If True, then a debug/saved_model_debug_info.pb file will be written
|
True, then a debug/saved_model_debug_info.pb file will be written with
|
||||||
with the contents of a GraphDebugInfo binary protocol buffer containing
|
the contents of a GraphDebugInfo binary protocol buffer containing stack
|
||||||
stack trace information for all ops and functions that are saved.
|
trace information for all ops and functions that are saved.
|
||||||
function_aliases: Python dict. Mapping from string to object returned by
|
function_aliases: Python dict. Mapping from string to object returned by
|
||||||
@tf.function.
|
@tf.function. A single tf.function can generate many ConcreteFunctions.
|
||||||
A single tf.function can generate many ConcreteFunctions. If a
|
If a downstream tool wants to refer to all concrete functions generated
|
||||||
downstream tool wants to refer to all concrete functions generated by a
|
by a single tf.function you can use the `function_aliases` argument to
|
||||||
single tf.function you can use the `function_aliases` argument to store
|
store a map from the alias name to all concrete function names.
|
||||||
a map from the alias name to all concrete function names.
|
|
||||||
E.g.
|
E.g.
|
||||||
```python
|
```python
|
||||||
class MyModel:
|
class MyModel:
|
||||||
|
@ -77,11 +78,21 @@ class SaveOptions(object):
|
||||||
})
|
})
|
||||||
tf.saved_model.save(model, export_dir, signatures, options)
|
tf.saved_model.save(model, export_dir, signatures, options)
|
||||||
```
|
```
|
||||||
|
experimental_io_device: string. Applies in a distributed setting.
|
||||||
|
TensorFlow device to use to access the filesystem. If `None` (default)
|
||||||
|
then for each variable the filesystem is accessed from the CPU:0 device
|
||||||
|
of the host where that variable is assigned. If specified, the
|
||||||
|
filesystem is instead accessed from that device for all variables.
|
||||||
|
|
||||||
|
This is for example useful if you want to save to a local directory,
|
||||||
|
such as "/tmp" when running in a distributed setting. In that case pass
|
||||||
|
a device for the host where the "/tmp" directory is accessible.
|
||||||
"""
|
"""
|
||||||
self.namespace_whitelist = _validate_namespace_whitelist(
|
self.namespace_whitelist = _validate_namespace_whitelist(
|
||||||
namespace_whitelist)
|
namespace_whitelist)
|
||||||
self.save_debug_info = save_debug_info
|
self.save_debug_info = save_debug_info
|
||||||
self.function_aliases = function_aliases if function_aliases else dict()
|
self.function_aliases = function_aliases if function_aliases else dict()
|
||||||
|
self.experimental_io_device = experimental_io_device
|
||||||
|
|
||||||
|
|
||||||
def _validate_namespace_whitelist(namespace_whitelist):
|
def _validate_namespace_whitelist(namespace_whitelist):
|
||||||
|
|
|
@ -577,6 +577,12 @@ class SavingOptionsTest(test.TestCase):
|
||||||
self.assertEqual(function_cache[0].name.decode("utf-8"),
|
self.assertEqual(function_cache[0].name.decode("utf-8"),
|
||||||
list(function_aliases.keys())[0])
|
list(function_aliases.keys())[0])
|
||||||
|
|
||||||
|
def test_accepts_io_device(self):
|
||||||
|
options = save_options.SaveOptions()
|
||||||
|
self.assertEqual(None, options.experimental_io_device)
|
||||||
|
options = save_options.SaveOptions(experimental_io_device="/job:localhost")
|
||||||
|
self.assertEqual("/job:localhost", options.experimental_io_device)
|
||||||
|
|
||||||
|
|
||||||
class AssetTests(test.TestCase):
|
class AssetTests(test.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -12,11 +12,20 @@ package(
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "checkpoint_options",
|
||||||
|
srcs = ["checkpoint_options.py"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:tf_export",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "functional_saver",
|
name = "functional_saver",
|
||||||
srcs = ["functional_saver.py"],
|
srcs = ["functional_saver.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":checkpoint_options",
|
||||||
":saveable_hook",
|
":saveable_hook",
|
||||||
":saveable_object",
|
":saveable_object",
|
||||||
":saveable_object_util",
|
":saveable_object_util",
|
||||||
|
@ -31,6 +40,7 @@ cuda_py_test(
|
||||||
"functional_saver_test.py",
|
"functional_saver_test.py",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":checkpoint_options",
|
||||||
":functional_saver",
|
":functional_saver",
|
||||||
":saveable_hook",
|
":saveable_hook",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Options for saving Checkpoints."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("train.CheckpointOptions")
|
||||||
|
class CheckpointOptions(object):
|
||||||
|
"""Options for constructing a Checkpoint.
|
||||||
|
|
||||||
|
Used as the `options` argument to `tf.train.Checkpoint.save()` to adjust
|
||||||
|
how variables are saved.
|
||||||
|
|
||||||
|
Example: Run IO ops on "localhost" while saving a checkpoint:
|
||||||
|
|
||||||
|
```
|
||||||
|
step = tf.Variable(0, name="step")
|
||||||
|
checkpoint = tf.train.Checkpoint(step=step)
|
||||||
|
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
|
||||||
|
checkpoint.save("/tmp/ckpt", options=options)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Define object attributes in __slots__ for improved memory and performance.
|
||||||
|
__slots__ = ("experimental_io_device",)
|
||||||
|
|
||||||
|
def __init__(self, experimental_io_device=None):
|
||||||
|
"""Creates an object that stores options for a Checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experimental_io_device: string. Applies in a distributed setting.
|
||||||
|
TensorFlow device to use to access the filesystem. If `None` (default)
|
||||||
|
then for each variable the filesystem is accessed from the CPU:0 device
|
||||||
|
of the host where that variable is assigned. If specified, the
|
||||||
|
filesystem is instead accessed from that device for all variables.
|
||||||
|
|
||||||
|
This is for example useful if you want to save to a local directory,
|
||||||
|
such as "/tmp" when running in a distributed setting. In that case pass
|
||||||
|
a device for the host where the "/tmp" directory is accessible.
|
||||||
|
"""
|
||||||
|
self.experimental_io_device = experimental_io_device
|
|
@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_io_ops
|
from tensorflow.python.ops import gen_io_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import saveable_hook
|
from tensorflow.python.training.saving import saveable_hook
|
||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
@ -52,15 +53,17 @@ class _SingleDeviceSaver(object):
|
||||||
"Expected a list of SaveableObjects, got %s." % (saveable,))
|
"Expected a list of SaveableObjects, got %s." % (saveable,))
|
||||||
self._saveable_objects = saveable_objects
|
self._saveable_objects = saveable_objects
|
||||||
|
|
||||||
def save(self, file_prefix):
|
def save(self, file_prefix, options=None):
|
||||||
"""Save the saveable objects to a checkpoint with `file_prefix`.
|
"""Save the saveable objects to a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A string or scalar string Tensor containing the prefix to
|
file_prefix: A string or scalar string Tensor containing the prefix to
|
||||||
save under.
|
save under.
|
||||||
|
options: Optional `CheckpointOptions` object.
|
||||||
Returns:
|
Returns:
|
||||||
An `Operation`, or None when executing eagerly.
|
An `Operation`, or None when executing eagerly.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
tensor_names = []
|
tensor_names = []
|
||||||
tensors = []
|
tensors = []
|
||||||
tensor_slices = []
|
tensor_slices = []
|
||||||
|
@ -69,19 +72,22 @@ class _SingleDeviceSaver(object):
|
||||||
tensor_names.append(spec.name)
|
tensor_names.append(spec.name)
|
||||||
tensors.append(spec.tensor)
|
tensors.append(spec.tensor)
|
||||||
tensor_slices.append(spec.slice_spec)
|
tensor_slices.append(spec.slice_spec)
|
||||||
with ops.device("cpu:0"):
|
save_device = options.experimental_io_device or "cpu:0"
|
||||||
|
with ops.device(save_device):
|
||||||
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
|
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
|
||||||
|
|
||||||
def restore(self, file_prefix):
|
def restore(self, file_prefix, options=None):
|
||||||
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A string or scalar string Tensor containing the prefix for
|
file_prefix: A string or scalar string Tensor containing the prefix for
|
||||||
files to read from.
|
files to read from.
|
||||||
|
options: Optional `CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping from SaveableObject names to restore operations.
|
A dictionary mapping from SaveableObject names to restore operations.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
restore_specs = []
|
restore_specs = []
|
||||||
tensor_structure = []
|
tensor_structure = []
|
||||||
for saveable in self._saveable_objects:
|
for saveable in self._saveable_objects:
|
||||||
|
@ -91,7 +97,8 @@ class _SingleDeviceSaver(object):
|
||||||
saveable_tensor_structure.append(spec.name)
|
saveable_tensor_structure.append(spec.name)
|
||||||
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
|
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
|
||||||
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
|
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
|
||||||
with ops.device("cpu:0"):
|
restore_device = options.experimental_io_device or "cpu:0"
|
||||||
|
with ops.device(restore_device):
|
||||||
restored_tensors = io_ops.restore_v2(
|
restored_tensors = io_ops.restore_v2(
|
||||||
file_prefix, tensor_names, tensor_slices, tensor_dtypes)
|
file_prefix, tensor_names, tensor_slices, tensor_dtypes)
|
||||||
structured_restored_tensors = nest.pack_sequence_as(
|
structured_restored_tensors = nest.pack_sequence_as(
|
||||||
|
@ -190,15 +197,17 @@ class MultiDeviceSaver(object):
|
||||||
with ops.control_dependencies(restore_ops.values()):
|
with ops.control_dependencies(restore_ops.values()):
|
||||||
return array_ops.identity(file_prefix)
|
return array_ops.identity(file_prefix)
|
||||||
|
|
||||||
def save(self, file_prefix):
|
def save(self, file_prefix, options=None):
|
||||||
"""Save the saveable objects to a checkpoint with `file_prefix`.
|
"""Save the saveable objects to a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A string or scalar string Tensor containing the prefix to
|
file_prefix: A string or scalar string Tensor containing the prefix to
|
||||||
save under.
|
save under.
|
||||||
|
options: Optional `CheckpointOptions` object.
|
||||||
Returns:
|
Returns:
|
||||||
An `Operation`, or None when executing eagerly.
|
An `Operation`, or None when executing eagerly.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
for callback in self._before_save_callbacks:
|
for callback in self._before_save_callbacks:
|
||||||
callback()
|
callback()
|
||||||
|
|
||||||
|
@ -253,32 +262,37 @@ class MultiDeviceSaver(object):
|
||||||
with ops.device(device):
|
with ops.device(device):
|
||||||
# _SingleDeviceSaver will use the CPU device when necessary, but initial
|
# _SingleDeviceSaver will use the CPU device when necessary, but initial
|
||||||
# read operations should be placed on the SaveableObject's device.
|
# read operations should be placed on the SaveableObject's device.
|
||||||
sharded_saves.append(saver.save(shard_prefix))
|
sharded_saves.append(saver.save(shard_prefix, options))
|
||||||
|
|
||||||
with ops.control_dependencies(sharded_saves):
|
with ops.control_dependencies(sharded_saves):
|
||||||
# Co-locates the merge step with the last device.
|
# Merge on the io_device if specified, otherwise co-locates the merge op
|
||||||
with ops.device(saveable_object_util.set_cpu0(last_device)):
|
# with the last device used.
|
||||||
|
merge_device = (options.experimental_io_device or
|
||||||
|
saveable_object_util.set_cpu0(last_device))
|
||||||
|
with ops.device(merge_device):
|
||||||
# V2 format write path consists of a metadata merge step. Once merged,
|
# V2 format write path consists of a metadata merge step. Once merged,
|
||||||
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
||||||
return gen_io_ops.merge_v2_checkpoints(
|
return gen_io_ops.merge_v2_checkpoints(
|
||||||
sharded_prefixes, file_prefix, delete_old_dirs=True)
|
sharded_prefixes, file_prefix, delete_old_dirs=True)
|
||||||
|
|
||||||
def restore(self, file_prefix):
|
def restore(self, file_prefix, options=None):
|
||||||
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A string or scalar string Tensor containing the prefix for
|
file_prefix: A string or scalar string Tensor containing the prefix for
|
||||||
files to read from.
|
files to read from.
|
||||||
|
options: Optional `CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping from SaveableObject names to restore operations.
|
A dictionary mapping from SaveableObject names to restore operations.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
restore_ops = {}
|
restore_ops = {}
|
||||||
# Sort by device name to avoid propagating non-deterministic dictionary
|
# Sort by device name to avoid propagating non-deterministic dictionary
|
||||||
# ordering in some Python versions.
|
# ordering in some Python versions.
|
||||||
for device, saver in sorted(self._single_device_savers.items()):
|
for device, saver in sorted(self._single_device_savers.items()):
|
||||||
with ops.device(device):
|
with ops.device(device):
|
||||||
restore_ops.update(saver.restore(file_prefix))
|
restore_ops.update(saver.restore(file_prefix, options))
|
||||||
|
|
||||||
for callback in self._after_restore_callbacks:
|
for callback in self._after_restore_callbacks:
|
||||||
callback()
|
callback()
|
||||||
|
|
|
@ -20,21 +20,37 @@ from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.eager import wrap_function
|
from tensorflow.python.eager import wrap_function
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
from tensorflow.python.training.saving import saveable_hook
|
from tensorflow.python.training.saving import saveable_hook
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
|
LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||||
|
|
||||||
|
|
||||||
class SaverTest(test.TestCase):
|
class SaverTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(SaverTest, self).setUp()
|
||||||
|
cpus = config.list_physical_devices("CPU")
|
||||||
|
# Set 3 virtual CPUs
|
||||||
|
config.set_logical_device_configuration(cpus[0], [
|
||||||
|
context.LogicalDeviceConfiguration(),
|
||||||
|
context.LogicalDeviceConfiguration(),
|
||||||
|
context.LogicalDeviceConfiguration()
|
||||||
|
])
|
||||||
|
self.local_options = checkpoint_options.CheckpointOptions(
|
||||||
|
experimental_io_device=LOCALHOST)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_resource_variable(self):
|
def test_resource_variable(self):
|
||||||
v1 = resource_variable_ops.ResourceVariable(2.)
|
v1 = resource_variable_ops.ResourceVariable(2.)
|
||||||
|
@ -55,6 +71,33 @@ class SaverTest(test.TestCase):
|
||||||
self.evaluate(second_saver.restore(prefix))
|
self.evaluate(second_saver.restore(prefix))
|
||||||
self.assertEqual(2., self.evaluate(v2))
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_resource_variable_use_localhost(self):
|
||||||
|
v1 = resource_variable_ops.ResourceVariable(2.)
|
||||||
|
self.evaluate(v1.initializer)
|
||||||
|
saver = functional_saver._SingleDeviceSaver(
|
||||||
|
saveable_object_util.saveable_objects_for_op(v1, "x"))
|
||||||
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
|
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
|
||||||
|
self.assertEqual(2, len(gfile.Glob(prefix + "*")))
|
||||||
|
self.evaluate(v1.assign(1.))
|
||||||
|
self.evaluate(saver.restore(prefix, self.local_options))
|
||||||
|
self.assertEqual(2., self.evaluate(v1))
|
||||||
|
|
||||||
|
v2 = resource_variable_ops.ResourceVariable(3.)
|
||||||
|
self.evaluate(v2.initializer)
|
||||||
|
second_saver = functional_saver._SingleDeviceSaver(
|
||||||
|
saveable_object_util.saveable_objects_for_op(v2, "x"))
|
||||||
|
self.evaluate(second_saver.restore(prefix, self.local_options))
|
||||||
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
# In graph mode, verify that the save and restore ops were set to run on
|
||||||
|
# localhost.
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
for op in ops.get_default_graph().get_operations():
|
||||||
|
if op.type in ("SaveV2", "RestoreV2"):
|
||||||
|
self.assertEqual(LOCALHOST, op.device)
|
||||||
|
|
||||||
def test_to_proto(self):
|
def test_to_proto(self):
|
||||||
v1 = resource_variable_ops.ResourceVariable(2.)
|
v1 = resource_variable_ops.ResourceVariable(2.)
|
||||||
saver = functional_saver.MultiDeviceSaver(
|
saver = functional_saver.MultiDeviceSaver(
|
||||||
|
@ -83,12 +126,7 @@ class SaverTest(test.TestCase):
|
||||||
second_saver.restore(save_path)
|
second_saver.restore(save_path)
|
||||||
self.assertEqual(2., self.evaluate(v2))
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
@test_util.run_v1_only(
|
@test_util.run_in_graph_and_eager_modes
|
||||||
"Needs an API to setup multiple devices, b/124805129")
|
|
||||||
# Set up multiple devices when graph building. Before test.main() we configure
|
|
||||||
# the devices for eager execution.
|
|
||||||
@test_util.run_in_graph_and_eager_modes(
|
|
||||||
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
|
||||||
def test_checkpoint_is_sharded_by_device(self):
|
def test_checkpoint_is_sharded_by_device(self):
|
||||||
with ops.device("cpu:0"):
|
with ops.device("cpu:0"):
|
||||||
v0 = resource_variable_ops.ResourceVariable(0.)
|
v0 = resource_variable_ops.ResourceVariable(0.)
|
||||||
|
@ -99,9 +137,9 @@ class SaverTest(test.TestCase):
|
||||||
|
|
||||||
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
|
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
|
||||||
saver = functional_saver.MultiDeviceSaver(
|
saver = functional_saver.MultiDeviceSaver(
|
||||||
list(saveable_object_util.saveable_objects_for_op(v0, "v0"))
|
list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
|
||||||
+ list(saveable_object_util.saveable_objects_for_op(v1, "v1"))
|
list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
|
||||||
+ list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
|
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
|
||||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
self.evaluate(saver.save(constant_op.constant(prefix)))
|
self.evaluate(saver.save(constant_op.constant(prefix)))
|
||||||
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
|
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
|
||||||
|
@ -113,8 +151,38 @@ class SaverTest(test.TestCase):
|
||||||
self.assertEqual(1., self.evaluate(v1))
|
self.assertEqual(1., self.evaluate(v1))
|
||||||
self.assertEqual(2., self.evaluate(v2))
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_checkpoint_multi_device_using_localhost(self):
|
||||||
|
with ops.device("cpu:0"):
|
||||||
|
v0 = resource_variable_ops.ResourceVariable(0.)
|
||||||
|
with ops.device("cpu:1"):
|
||||||
|
v1 = resource_variable_ops.ResourceVariable(1.)
|
||||||
|
with ops.device("cpu:2"):
|
||||||
|
v2 = resource_variable_ops.ResourceVariable(2.)
|
||||||
|
|
||||||
class SaveableHookTest(test.TestCase):
|
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
|
||||||
|
saver = functional_saver.MultiDeviceSaver(
|
||||||
|
list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
|
||||||
|
list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
|
||||||
|
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
|
||||||
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
|
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
|
||||||
|
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
|
||||||
|
self.evaluate(v0.assign(-1.))
|
||||||
|
self.evaluate(v1.assign(-1.))
|
||||||
|
self.evaluate(v2.assign(-1.))
|
||||||
|
self.evaluate(
|
||||||
|
saver.restore(constant_op.constant(prefix), self.local_options))
|
||||||
|
self.assertEqual(0., self.evaluate(v0))
|
||||||
|
self.assertEqual(1., self.evaluate(v1))
|
||||||
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
# In graph mode, verify that the save and restore ops were set to run on
|
||||||
|
# localhost.
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
for op in ops.get_default_graph().get_operations():
|
||||||
|
if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
|
||||||
|
self.assertEqual(LOCALHOST, op.device)
|
||||||
|
|
||||||
def test_callbacks_run(self):
|
def test_callbacks_run(self):
|
||||||
# Use dict because an int would be shadowed inside callback.
|
# Use dict because an int would be shadowed inside callback.
|
||||||
|
@ -144,6 +212,5 @@ class SaveableHookTest(test.TestCase):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution(
|
ops.enable_eager_execution()
|
||||||
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
@ -150,6 +150,7 @@ py_library(
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/keras:backend",
|
"//tensorflow/python/keras:backend",
|
||||||
|
"//tensorflow/python/training/saving:checkpoint_options",
|
||||||
"//tensorflow/python/training/saving:functional_saver",
|
"//tensorflow/python/training/saving:functional_saver",
|
||||||
"//tensorflow/python/training/saving:saveable_object_util",
|
"//tensorflow/python/training/saving:saveable_object_util",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
|
@ -191,6 +192,7 @@ tf_py_test(
|
||||||
"//tensorflow/python/keras:engine",
|
"//tensorflow/python/keras:engine",
|
||||||
"//tensorflow/python/keras/layers",
|
"//tensorflow/python/keras/layers",
|
||||||
"//tensorflow/python/keras/optimizer_v2",
|
"//tensorflow/python/keras/optimizer_v2",
|
||||||
|
"//tensorflow/python/training/saving:checkpoint_options",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
|
|
@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import py_checkpoint_reader
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training import saver as v1_saver_lib
|
from tensorflow.python.training import saver as v1_saver_lib
|
||||||
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
|
@ -168,7 +169,7 @@ class _CheckpointRestoreCoordinator(object):
|
||||||
"""Holds the status of an object-based checkpoint load."""
|
"""Holds the status of an object-based checkpoint load."""
|
||||||
|
|
||||||
def __init__(self, object_graph_proto, save_path, save_path_tensor,
|
def __init__(self, object_graph_proto, save_path, save_path_tensor,
|
||||||
restore_op_cache, graph_view):
|
restore_op_cache, graph_view, options):
|
||||||
"""Specify the checkpoint being loaded.
|
"""Specify the checkpoint being loaded.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -184,7 +185,9 @@ class _CheckpointRestoreCoordinator(object):
|
||||||
`restore()` calls.
|
`restore()` calls.
|
||||||
graph_view: A graph_view_lib.ObjectGraphView object for the restored
|
graph_view: A graph_view_lib.ObjectGraphView object for the restored
|
||||||
objects.
|
objects.
|
||||||
|
options: A CheckpointOptions object.
|
||||||
"""
|
"""
|
||||||
|
self.options = options
|
||||||
self.object_graph_proto = object_graph_proto
|
self.object_graph_proto = object_graph_proto
|
||||||
self.restore_uid = ops.uid()
|
self.restore_uid = ops.uid()
|
||||||
# Maps from proto ids to lists of attributes which were in the checkpoint
|
# Maps from proto ids to lists of attributes which were in the checkpoint
|
||||||
|
@ -291,7 +294,7 @@ class _CheckpointRestoreCoordinator(object):
|
||||||
("Saveable keys changed when validating. Got back %s, was "
|
("Saveable keys changed when validating. Got back %s, was "
|
||||||
"expecting %s") % (tensor_saveables.keys(), validated_names))
|
"expecting %s") % (tensor_saveables.keys(), validated_names))
|
||||||
new_restore_ops = functional_saver.MultiDeviceSaver(
|
new_restore_ops = functional_saver.MultiDeviceSaver(
|
||||||
validated_saveables).restore(self.save_path_tensor)
|
validated_saveables).restore(self.save_path_tensor, self.options)
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
for name, restore_op in sorted(new_restore_ops.items()):
|
for name, restore_op in sorted(new_restore_ops.items()):
|
||||||
restore_ops.append(restore_op)
|
restore_ops.append(restore_op)
|
||||||
|
@ -1113,13 +1116,15 @@ class TrackableSaver(object):
|
||||||
|
|
||||||
def _save_cached_when_graph_building(self,
|
def _save_cached_when_graph_building(self,
|
||||||
file_prefix,
|
file_prefix,
|
||||||
object_graph_tensor=None):
|
object_graph_tensor,
|
||||||
|
options):
|
||||||
"""Create or retrieve save ops.
|
"""Create or retrieve save ops.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: The prefix for saved checkpoint files.
|
file_prefix: The prefix for saved checkpoint files.
|
||||||
object_graph_tensor: A `Tensor` to which the current object graph will be
|
object_graph_tensor: A `Tensor` to which the current object graph will be
|
||||||
fed.
|
fed.
|
||||||
|
options: `CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A two-element tuple with a filename tensor and a feed_dict of tensors to
|
A two-element tuple with a filename tensor and a feed_dict of tensors to
|
||||||
|
@ -1137,14 +1142,15 @@ class TrackableSaver(object):
|
||||||
# var_list.
|
# var_list.
|
||||||
or context.executing_eagerly() or ops.inside_function()):
|
or context.executing_eagerly() or ops.inside_function()):
|
||||||
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
|
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
|
||||||
save_op = saver.save(file_prefix)
|
save_op = saver.save(file_prefix, options=options)
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
with ops.control_dependencies([save_op]):
|
with ops.control_dependencies([save_op]):
|
||||||
self._cached_save_operation = array_ops.identity(file_prefix)
|
self._cached_save_operation = array_ops.identity(file_prefix)
|
||||||
self._last_save_object_graph = graph_proto
|
self._last_save_object_graph = graph_proto
|
||||||
return self._cached_save_operation, feed_additions
|
return self._cached_save_operation, feed_additions
|
||||||
|
|
||||||
def save(self, file_prefix, checkpoint_number=None, session=None):
|
def save(self, file_prefix, checkpoint_number=None, session=None,
|
||||||
|
options=None):
|
||||||
"""Save a training checkpoint.
|
"""Save a training checkpoint.
|
||||||
|
|
||||||
The saved checkpoint includes variables created by this object and any
|
The saved checkpoint includes variables created by this object and any
|
||||||
|
@ -1162,10 +1168,12 @@ class TrackableSaver(object):
|
||||||
session: The session to evaluate variables in. Ignored when executing
|
session: The session to evaluate variables in. Ignored when executing
|
||||||
eagerly. If not provided when graph building, the default session is
|
eagerly. If not provided when graph building, the default session is
|
||||||
used.
|
used.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The full path to the checkpoint.
|
The full path to the checkpoint.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
feed_dict = {}
|
feed_dict = {}
|
||||||
use_session = (not context.executing_eagerly() and
|
use_session = (not context.executing_eagerly() and
|
||||||
not ops.inside_function())
|
not ops.inside_function())
|
||||||
|
@ -1189,7 +1197,7 @@ class TrackableSaver(object):
|
||||||
|
|
||||||
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
||||||
save_path, new_feed_additions = self._save_cached_when_graph_building(
|
save_path, new_feed_additions = self._save_cached_when_graph_building(
|
||||||
file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
|
file_prefix_tensor, object_graph_tensor, options)
|
||||||
if new_feed_additions:
|
if new_feed_additions:
|
||||||
feed_dict.update(new_feed_additions)
|
feed_dict.update(new_feed_additions)
|
||||||
if not use_session:
|
if not use_session:
|
||||||
|
@ -1202,7 +1210,7 @@ class TrackableSaver(object):
|
||||||
else:
|
else:
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def restore(self, save_path):
|
def restore(self, save_path, options=None):
|
||||||
"""Restore a training checkpoint.
|
"""Restore a training checkpoint.
|
||||||
|
|
||||||
Restores `root_trackable` and any objects that it tracks
|
Restores `root_trackable` and any objects that it tracks
|
||||||
|
@ -1250,6 +1258,7 @@ class TrackableSaver(object):
|
||||||
object which may run initializers for objects in the dependency graph.
|
object which may run initializers for objects in the dependency graph.
|
||||||
If the checkpoint was written by the name-based
|
If the checkpoint was written by the name-based
|
||||||
`tf.compat.v1.train.Saver`, names are used to match variables.
|
`tf.compat.v1.train.Saver`, names are used to match variables.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A load status object, which can be used to make assertions about the
|
A load status object, which can be used to make assertions about the
|
||||||
|
@ -1260,6 +1269,7 @@ class TrackableSaver(object):
|
||||||
If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
|
If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
|
||||||
object is returned which runs restore ops from a name-based saver.
|
object is returned which runs restore ops from a name-based saver.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
if save_path is None:
|
if save_path is None:
|
||||||
return InitializationOnlyStatus(self._graph_view, ops.uid())
|
return InitializationOnlyStatus(self._graph_view, ops.uid())
|
||||||
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
|
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
|
||||||
|
@ -1304,7 +1314,8 @@ class TrackableSaver(object):
|
||||||
save_path=save_path,
|
save_path=save_path,
|
||||||
save_path_tensor=file_prefix_tensor,
|
save_path_tensor=file_prefix_tensor,
|
||||||
restore_op_cache=self._restore_op_cache,
|
restore_op_cache=self._restore_op_cache,
|
||||||
graph_view=self._graph_view)
|
graph_view=self._graph_view,
|
||||||
|
options=options)
|
||||||
base.CheckpointPosition(
|
base.CheckpointPosition(
|
||||||
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
|
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
|
||||||
load_status = CheckpointLoadStatus(
|
load_status = CheckpointLoadStatus(
|
||||||
|
@ -1736,6 +1747,8 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
checkpoint_directory = "/tmp/training_checkpoints"
|
checkpoint_directory = "/tmp/training_checkpoints"
|
||||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||||
|
|
||||||
|
# Create a Checkpoint that will manage two objects with trackable state,
|
||||||
|
# one we name "optimizer" and the other we name "model".
|
||||||
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
||||||
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||||
for _ in range(num_training_steps):
|
for _ in range(num_training_steps):
|
||||||
|
@ -1744,7 +1757,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||||
```
|
```
|
||||||
|
|
||||||
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
|
`Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
|
||||||
checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
|
checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
|
||||||
writes and
|
writes and
|
||||||
reads `variable.name` based checkpoints. Object-based checkpointing saves a
|
reads `variable.name` based checkpoints. Object-based checkpointing saves a
|
||||||
|
@ -1757,7 +1770,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
arguments to their constructors, and each dependency is given a name that is
|
arguments to their constructors, and each dependency is given a name that is
|
||||||
identical to the name of the keyword argument for which it was created.
|
identical to the name of the keyword argument for which it was created.
|
||||||
TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
|
TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
|
||||||
dependencies on their variables (e.g. "kernel" and "bias" for
|
dependencies on their own variables (e.g. "kernel" and "bias" for
|
||||||
`tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
|
`tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
|
||||||
dependencies easy in user-defined classes, since `Model` hooks into attribute
|
dependencies easy in user-defined classes, since `Model` hooks into attribute
|
||||||
assignment. For example:
|
assignment. For example:
|
||||||
|
@ -1840,7 +1853,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
dtype=dtypes.int64,
|
dtype=dtypes.int64,
|
||||||
trainable=False))
|
trainable=False))
|
||||||
|
|
||||||
def write(self, file_prefix):
|
def write(self, file_prefix, options=None):
|
||||||
"""Writes a training checkpoint.
|
"""Writes a training checkpoint.
|
||||||
|
|
||||||
The checkpoint includes variables created by this object and any
|
The checkpoint includes variables created by this object and any
|
||||||
|
@ -1854,14 +1867,35 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
|
|
||||||
Checkpoints written with `write` must be read with `read`.
|
Checkpoints written with `write` must be read with `read`.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
step = tf.Variable(0, name="step")
|
||||||
|
checkpoint = tf.train.Checkpoint(step=step)
|
||||||
|
checkpoint.write("/tmp/ckpt")
|
||||||
|
|
||||||
|
# Later, read the checkpoint with read()
|
||||||
|
checkpoint.read("/tmp/ckpt").assert_consumed()
|
||||||
|
|
||||||
|
# You can also pass options to write() and read(). For example this
|
||||||
|
# runs the IO ops on the localhost:
|
||||||
|
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
|
||||||
|
checkpoint.write("/tmp/ckpt", options=options)
|
||||||
|
|
||||||
|
# Later, read the checkpoint with read()
|
||||||
|
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A prefix to use for the checkpoint filenames
|
file_prefix: A prefix to use for the checkpoint filenames
|
||||||
(/path/to/directory/and_a_prefix).
|
(/path/to/directory/and_a_prefix).
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The full path to the checkpoint (i.e. `file_prefix`).
|
The full path to the checkpoint (i.e. `file_prefix`).
|
||||||
"""
|
"""
|
||||||
output = self._saver.save(file_prefix=file_prefix)
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
|
output = self._saver.save(file_prefix=file_prefix, options=options)
|
||||||
if tensor_util.is_tensor(output):
|
if tensor_util.is_tensor(output):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return compat.as_str(output.numpy())
|
return compat.as_str(output.numpy())
|
||||||
|
@ -1884,7 +1918,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
self._maybe_create_save_counter()
|
self._maybe_create_save_counter()
|
||||||
return self._save_counter
|
return self._save_counter
|
||||||
|
|
||||||
def save(self, file_prefix):
|
def save(self, file_prefix, options=None):
|
||||||
"""Saves a training checkpoint and provides basic checkpoint management.
|
"""Saves a training checkpoint and provides basic checkpoint management.
|
||||||
|
|
||||||
The saved checkpoint includes variables created by this object and any
|
The saved checkpoint includes variables created by this object and any
|
||||||
|
@ -1898,14 +1932,33 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
provided by other utilities which also wrap `write` and `read`.
|
provided by other utilities which also wrap `write` and `read`.
|
||||||
(`tf.train.CheckpointManager` for example).
|
(`tf.train.CheckpointManager` for example).
|
||||||
|
|
||||||
|
```
|
||||||
|
step = tf.Variable(0, name="step")
|
||||||
|
checkpoint = tf.train.Checkpoint(step=step)
|
||||||
|
checkpoint.save("/tmp/ckpt")
|
||||||
|
|
||||||
|
# Later, read the checkpoint with restore()
|
||||||
|
checkpoint.restore("/tmp/ckpt").assert_consumed()
|
||||||
|
|
||||||
|
# You can also pass options to save() and restore(). For example this
|
||||||
|
# runs the IO ops on the localhost:
|
||||||
|
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
|
||||||
|
checkpoint.save("/tmp/ckpt", options=options)
|
||||||
|
|
||||||
|
# Later, read the checkpoint with restore()
|
||||||
|
checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A prefix to use for the checkpoint filenames
|
file_prefix: A prefix to use for the checkpoint filenames
|
||||||
(/path/to/directory/and_a_prefix). Names are generated based on this
|
(/path/to/directory/and_a_prefix). Names are generated based on this
|
||||||
prefix and `Checkpoint.save_counter`.
|
prefix and `Checkpoint.save_counter`.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The full path to the checkpoint.
|
The full path to the checkpoint.
|
||||||
"""
|
"""
|
||||||
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
graph_building = not context.executing_eagerly()
|
graph_building = not context.executing_eagerly()
|
||||||
if graph_building:
|
if graph_building:
|
||||||
if ops.inside_function():
|
if ops.inside_function():
|
||||||
|
@ -1931,7 +1984,8 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
checkpoint_number = session.run(self._save_assign_op)
|
checkpoint_number = session.run(self._save_assign_op)
|
||||||
else:
|
else:
|
||||||
checkpoint_number = assign_op.numpy()
|
checkpoint_number = assign_op.numpy()
|
||||||
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
|
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
|
||||||
|
options=options)
|
||||||
checkpoint_management.update_checkpoint_state_internal(
|
checkpoint_management.update_checkpoint_state_internal(
|
||||||
save_dir=os.path.dirname(file_prefix),
|
save_dir=os.path.dirname(file_prefix),
|
||||||
model_checkpoint_path=file_path,
|
model_checkpoint_path=file_path,
|
||||||
|
@ -1939,7 +1993,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
save_relative_paths=True)
|
save_relative_paths=True)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
def read(self, save_path):
|
def read(self, save_path, options=None):
|
||||||
"""Read a training checkpoint written with `write`.
|
"""Read a training checkpoint written with `write`.
|
||||||
|
|
||||||
Reads this `Checkpoint` and any objects it depends on.
|
Reads this `Checkpoint` and any objects it depends on.
|
||||||
|
@ -1962,18 +2016,25 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
# Later, load the checkpoint with read()
|
# Later, load the checkpoint with read()
|
||||||
# With restore() assert_consumed() would have failed.
|
# With restore() assert_consumed() would have failed.
|
||||||
checkpoint.read(path).assert_consumed()
|
checkpoint.read(path).assert_consumed()
|
||||||
|
|
||||||
|
# You can also pass options to read(). For example, this
|
||||||
|
# runs the IO ops on the localhost:
|
||||||
|
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
|
||||||
|
checkpoint.read(path, options=options)
|
||||||
```
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
save_path: The path to the checkpoint as returned by `write`.
|
save_path: The path to the checkpoint as returned by `write`.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A load status object, which can be used to make assertions about the
|
A load status object, which can be used to make assertions about the
|
||||||
status of a checkpoint restoration. See `restore` for details.
|
status of a checkpoint restoration. See `restore` for details.
|
||||||
"""
|
"""
|
||||||
return self._saver.restore(save_path=save_path)
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
|
return self._saver.restore(save_path=save_path, options=options)
|
||||||
|
|
||||||
def restore(self, save_path):
|
def restore(self, save_path, options=None):
|
||||||
"""Restore a training checkpoint.
|
"""Restore a training checkpoint.
|
||||||
|
|
||||||
Restores this `Checkpoint` and any objects it depends on.
|
Restores this `Checkpoint` and any objects it depends on.
|
||||||
|
@ -1995,6 +2056,10 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
```python
|
```python
|
||||||
checkpoint = tf.train.Checkpoint( ... )
|
checkpoint = tf.train.Checkpoint( ... )
|
||||||
checkpoint.restore(path).assert_consumed()
|
checkpoint.restore(path).assert_consumed()
|
||||||
|
|
||||||
|
# You can additionally pass options to restore():
|
||||||
|
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
|
||||||
|
checkpoint.restore(path, options=options).assert_consumed()
|
||||||
```
|
```
|
||||||
|
|
||||||
An exception will be raised if any Python objects in the dependency graph
|
An exception will be raised if any Python objects in the dependency graph
|
||||||
|
@ -2011,6 +2076,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
`tf.train.latest_checkpoint`. If the checkpoint was written by the
|
`tf.train.latest_checkpoint`. If the checkpoint was written by the
|
||||||
name-based `tf.compat.v1.train.Saver`, names are used to match
|
name-based `tf.compat.v1.train.Saver`, names are used to match
|
||||||
variables.
|
variables.
|
||||||
|
options: Optional `tf.train.CheckpointOptions` object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A load status object, which can be used to make assertions about the
|
A load status object, which can be used to make assertions about the
|
||||||
|
@ -2049,7 +2115,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||||
checkpoint file or object when the `Checkpoint` object is deleted
|
checkpoint file or object when the `Checkpoint` object is deleted
|
||||||
(often at program shutdown).
|
(often at program shutdown).
|
||||||
"""
|
"""
|
||||||
status = self.read(save_path)
|
status = self.read(save_path, options=options)
|
||||||
# Create the save counter now so it gets initialized with other variables
|
# Create the save counter now so it gets initialized with other variables
|
||||||
# when graph building. Creating it earlier would lead to errors when using,
|
# when graph building. Creating it earlier would lead to errors when using,
|
||||||
# say, train.Saver() to save the model before initializing it.
|
# say, train.Saver() to save the model before initializing it.
|
||||||
|
|
|
@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
|
@ -409,6 +410,28 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||||
del ckpt
|
del ckpt
|
||||||
status.assert_consumed()
|
status.assert_consumed()
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testPassingCheckpointOptions(self):
|
||||||
|
localhost = "/job:localhost/device:CPU:0"
|
||||||
|
options = checkpoint_options.CheckpointOptions(
|
||||||
|
experimental_io_device=localhost)
|
||||||
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
|
v = variable_scope.get_variable(name="v", initializer=0.)
|
||||||
|
self.evaluate(v.initializer)
|
||||||
|
ckpt = trackable_utils.Checkpoint(v=v)
|
||||||
|
self.evaluate(trackable_utils.gather_initializers(ckpt))
|
||||||
|
save_path = ckpt.save(file_prefix=prefix, options=options)
|
||||||
|
status = ckpt.restore(save_path=save_path, options=options)
|
||||||
|
del ckpt
|
||||||
|
status.assert_consumed()
|
||||||
|
|
||||||
|
# In graph mode, verify that the save and restore ops were set to run on
|
||||||
|
# localhost.
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
for op in ops.get_default_graph().get_operations():
|
||||||
|
if op.type in ("SaveV2", "RestoreV2"):
|
||||||
|
self.assertEqual(localhost, op.device)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testSaveRestore(self):
|
def testSaveRestore(self):
|
||||||
model = MyModel()
|
model = MyModel()
|
||||||
|
|
|
@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
|
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "experimental_io_device"
|
||||||
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "function_aliases"
|
name: "function_aliases"
|
||||||
mtype: "<type \'member_descriptor\'>"
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
@ -16,6 +20,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
path: "tensorflow.train.CheckpointOptions"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "experimental_io_device"
|
||||||
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
|
@ -28,6 +28,10 @@ tf_module {
|
||||||
name: "CheckpointManager"
|
name: "CheckpointManager"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "CheckpointOptions"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "CheckpointSaverHook"
|
name: "CheckpointSaverHook"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
|
|
@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
|
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "experimental_io_device"
|
||||||
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "function_aliases"
|
name: "function_aliases"
|
||||||
mtype: "<type \'member_descriptor\'>"
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
@ -16,6 +20,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
path: "tensorflow.train.CheckpointOptions"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "experimental_io_device"
|
||||||
|
mtype: "<type \'member_descriptor\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,18 +14,18 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "read"
|
name: "read"
|
||||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "restore"
|
name: "restore"
|
||||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "save"
|
name: "save"
|
||||||
argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "write"
|
name: "write"
|
||||||
argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,10 @@ tf_module {
|
||||||
name: "CheckpointManager"
|
name: "CheckpointManager"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "CheckpointOptions"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "ClusterDef"
|
name: "ClusterDef"
|
||||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||||
|
|
Loading…
Reference in New Issue