Add an option to choose the I/O Device for saving and loading models.

This option enables saving and restoring models to or from filesystems that
are only accessible from the localhost when using multiple devices.

The option is available for:
 - Saving models: tf.saved_model.save(), via tf.saved_model.SaveOptions
 - Checkpoints: tf.train.Checkpoint, via tf.train.CheckpointOptions
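
For illustration, a minimal end-to-end sketch of both entry points (a sketch
only: the "/job:localhost" device string and "/tmp" paths are placeholders for
whatever host can reach the target filesystem):

```python
import tensorflow as tf

# Any trackable object works; a tf.Module with one variable keeps it small.
model = tf.Module()
model.v = tf.Variable(3.0)

# SavedModel: choose the I/O device through tf.saved_model.SaveOptions.
save_options = tf.saved_model.SaveOptions(
    experimental_io_device="/job:localhost")
tf.saved_model.save(model, "/tmp/saved_model", options=save_options)

# Checkpoints: the same knob on tf.train.CheckpointOptions.
ckpt = tf.train.Checkpoint(model=model)
ckpt_options = tf.train.CheckpointOptions(
    experimental_io_device="/job:localhost")
path = ckpt.save("/tmp/ckpt", options=ckpt_options)
ckpt.restore(path, options=ckpt_options)
```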

PiperOrigin-RevId: 307858098
Change-Id: I4cd0a81424e306f0eac40bfb30d5067dfc02d1be
Author: A. Unique TensorFlower
Date: 2020-04-22 11:23:53 -07:00
Committed by: TensorFlower Gardener
Commit: e853835634 (parent: 0493a020d4)

18 changed files with 365 additions and 60 deletions

View File

@@ -310,6 +310,7 @@ py_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/training/saving:checkpoint_options",
         "//tensorflow/python/training/saving:functional_saver",
         "//tensorflow/python/training/tracking",
         "//tensorflow/python/training/tracking:base",

View File

@@ -52,6 +52,7 @@ from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import signature_serialization
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.saved_model import utils_impl
+from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import functional_saver
 from tensorflow.python.training.tracking import base
 from tensorflow.python.training.tracking import graph_view
@@ -941,6 +942,7 @@ def save(obj, export_dir, signatures=None, options=None):
   May not be called from within a function body.
   @end_compatibility
   """
+  options = options or save_options.SaveOptions()
   # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
   # compatible (no sessions) and share it with this export API rather than
   # making a SavedModel proto and writing it directly.
@@ -954,7 +956,10 @@ def save(obj, export_dir, signatures=None, options=None):
   # Write the checkpoint, copy assets into the assets directory, and write out
   # the SavedModel proto itself.
   utils_impl.get_or_create_variables_dir(export_dir)
-  object_saver.save(utils_impl.get_variables_path(export_dir))
+  ckpt_options = checkpoint_options.CheckpointOptions(
+      experimental_io_device=options.experimental_io_device)
+  object_saver.save(utils_impl.get_variables_path(export_dir),
+                    options=ckpt_options)
   builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                               export_dir)
   # Note that this needs to be the last file operation when saving the
@@ -976,6 +981,7 @@ def save(obj, export_dir, signatures=None, options=None):

 def export_meta_graph(obj, filename, signatures=None, options=None):
   """Exports the MetaGraph proto to a file."""
+  options = options or save_options.SaveOptions()
   export_dir = os.path.dirname(filename)
   meta_graph_def, exported_graph, _, _ = _build_meta_graph(
       obj, export_dir, signatures, options)
@@ -1001,7 +1007,6 @@ def _build_meta_graph(obj, export_dir, signatures, options,
   if not isinstance(obj, base.Trackable):
     raise ValueError(
         "Expected a Trackable object for export, got {}.".format(obj))
-  options = options or save_options.SaveOptions()
   meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
   checkpoint_graph_view = _AugmentedGraphView(obj)

View File

@@ -33,12 +33,14 @@ class SaveOptions(object):
   """

   # Define object attributes in __slots__ for improved memory and performance.
-  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases")
+  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
+               "experimental_io_device")

   def __init__(self,
                namespace_whitelist=None,
                save_debug_info=False,
-               function_aliases=None):
+               function_aliases=None,
+               experimental_io_device=None):
     """Creates an object that stores options for SavedModel saving.

     Args:
@@ -46,16 +48,15 @@ class SaveOptions(object):
         when saving a model. Saving an object that uses namespaced ops must
         explicitly add all namespaces to the whitelist. The namespaced ops must
         be registered into the framework when loading the SavedModel.
-      save_debug_info: Boolean indicating whether debug information is saved.
-        If True, then a debug/saved_model_debug_info.pb file will be written
-        with the contents of a GraphDebugInfo binary protocol buffer containing
-        stack trace information for all ops and functions that are saved.
+      save_debug_info: Boolean indicating whether debug information is saved. If
+        True, then a debug/saved_model_debug_info.pb file will be written with
+        the contents of a GraphDebugInfo binary protocol buffer containing stack
+        trace information for all ops and functions that are saved.
       function_aliases: Python dict. Mapping from string to object returned by
-        @tf.function.
-        A single tf.function can generate many ConcreteFunctions. If a
-        downstream tool wants to refer to all concrete functions generated by a
-        single tf.function you can use the `function_aliases` argument to store
-        a map from the alias name to all concrete function names.
+        @tf.function. A single tf.function can generate many ConcreteFunctions.
+        If a downstream tool wants to refer to all concrete functions generated
+        by a single tf.function you can use the `function_aliases` argument to
+        store a map from the alias name to all concrete function names.
         E.g.
         ```python
         class MyModel:
@@ -77,11 +78,21 @@ class SaveOptions(object):
         })
         tf.saved_model.save(model, export_dir, signatures, options)
         ```
+      experimental_io_device: string. Applies in a distributed setting.
+        Tensorflow device to use to access the filesystem. If `None` (default)
+        then for each variable the filesystem is accessed from the CPU:0 device
+        of the host where that variable is assigned. If specified, the
+        filesystem is instead accessed from that device for all variables.
+
+        This is for example useful if you want to save to a local directory,
+        such as "/tmp" when running in a distributed setting. In that case pass
+        a device for the host where the "/tmp" directory is accessible.
     """
     self.namespace_whitelist = _validate_namespace_whitelist(
         namespace_whitelist)
     self.save_debug_info = save_debug_info
     self.function_aliases = function_aliases if function_aliases else dict()
+    self.experimental_io_device = experimental_io_device


 def _validate_namespace_whitelist(namespace_whitelist):
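
To make the `experimental_io_device` semantics above concrete, a minimal
sketch (the device string and paths are illustrative; `my_module` is a
stand-in for any trackable object):

```python
import tensorflow as tf

my_module = tf.Module()
my_module.v = tf.Variable(1.0)

# Write to a directory that only the local host can see, e.g. "/tmp", by
# pinning all filesystem access to the localhost job's device instead of the
# CPU:0 device of whichever host owns each variable.
options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(my_module, "/tmp/module_no_signatures", options=options)
```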

View File

@@ -577,6 +577,12 @@ class SavingOptionsTest(test.TestCase):
     self.assertEqual(function_cache[0].name.decode("utf-8"),
                      list(function_aliases.keys())[0])

+  def test_accepts_io_device(self):
+    options = save_options.SaveOptions()
+    self.assertEqual(None, options.experimental_io_device)
+    options = save_options.SaveOptions(experimental_io_device="/job:localhost")
+    self.assertEqual("/job:localhost", options.experimental_io_device)
+

 class AssetTests(test.TestCase):

View File

@@ -12,11 +12,20 @@ package(

 exports_files(["LICENSE"])

+py_library(
+    name = "checkpoint_options",
+    srcs = ["checkpoint_options.py"],
+    deps = [
+        "//tensorflow/python:tf_export",
+    ],
+)
+
 py_library(
     name = "functional_saver",
     srcs = ["functional_saver.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":checkpoint_options",
         ":saveable_hook",
         ":saveable_object",
         ":saveable_object_util",
@@ -31,6 +40,7 @@ cuda_py_test(
         "functional_saver_test.py",
     ],
     deps = [
+        ":checkpoint_options",
         ":functional_saver",
         ":saveable_hook",
         "//tensorflow/python/eager:test",

View File

@@ -0,0 +1,58 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Options for saving Checkpoints."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("train.CheckpointOptions")
+class CheckpointOptions(object):
+  """Options for constructing a Checkpoint.
+
+  Used as the `options` argument to the saving and restoring methods of
+  `tf.train.Checkpoint` to adjust how variables are saved.
+
+  Example: Run IO ops on "localhost" while saving a checkpoint:
+
+  ```
+  step = tf.Variable(0, name="step")
+  checkpoint = tf.train.Checkpoint(step=step)
+  options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
+  checkpoint.save("/tmp/ckpt", options=options)
+  ```
+  """
+
+  # Define object attributes in __slots__ for improved memory and performance.
+  __slots__ = ("experimental_io_device",)
+
+  def __init__(self, experimental_io_device=None):
+    """Creates an object that stores options for a Checkpoint.
+
+    Args:
+      experimental_io_device: string. Applies in a distributed setting.
+        Tensorflow device to use to access the filesystem. If `None` (default)
+        then for each variable the filesystem is accessed from the CPU:0 device
+        of the host where that variable is assigned. If specified, the
+        filesystem is instead accessed from that device for all variables.
+
+        This is for example useful if you want to save to a local directory,
+        such as "/tmp" when running in a distributed setting. In that case pass
+        a device for the host where the "/tmp" directory is accessible.
+    """
+    self.experimental_io_device = experimental_io_device

View File

@@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import string_ops
+from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import saveable_hook
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
@@ -52,15 +53,17 @@ class _SingleDeviceSaver(object):
             "Expected a list of SaveableObjects, got %s." % (saveable,))
     self._saveable_objects = saveable_objects

-  def save(self, file_prefix):
+  def save(self, file_prefix, options=None):
     """Save the saveable objects to a checkpoint with `file_prefix`.

     Args:
       file_prefix: A string or scalar string Tensor containing the prefix to
         save under.
+      options: Optional `CheckpointOptions` object.

     Returns:
       An `Operation`, or None when executing eagerly.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     tensor_names = []
     tensors = []
     tensor_slices = []
@@ -69,19 +72,22 @@ class _SingleDeviceSaver(object):
         tensor_names.append(spec.name)
         tensors.append(spec.tensor)
         tensor_slices.append(spec.slice_spec)
-    with ops.device("cpu:0"):
+    save_device = options.experimental_io_device or "cpu:0"
+    with ops.device(save_device):
       return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)

-  def restore(self, file_prefix):
+  def restore(self, file_prefix, options=None):
     """Restore the saveable objects from a checkpoint with `file_prefix`.

     Args:
       file_prefix: A string or scalar string Tensor containing the prefix for
         files to read from.
+      options: Optional `CheckpointOptions` object.

     Returns:
       A dictionary mapping from SaveableObject names to restore operations.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     restore_specs = []
     tensor_structure = []
     for saveable in self._saveable_objects:
@@ -91,7 +97,8 @@ class _SingleDeviceSaver(object):
         saveable_tensor_structure.append(spec.name)
         restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
     tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
-    with ops.device("cpu:0"):
+    restore_device = options.experimental_io_device or "cpu:0"
+    with ops.device(restore_device):
       restored_tensors = io_ops.restore_v2(
           file_prefix, tensor_names, tensor_slices, tensor_dtypes)
     structured_restored_tensors = nest.pack_sequence_as(
@@ -190,15 +197,17 @@ class MultiDeviceSaver(object):
       with ops.control_dependencies(restore_ops.values()):
         return array_ops.identity(file_prefix)

-  def save(self, file_prefix):
+  def save(self, file_prefix, options=None):
     """Save the saveable objects to a checkpoint with `file_prefix`.

     Args:
       file_prefix: A string or scalar string Tensor containing the prefix to
         save under.
+      options: Optional `CheckpointOptions` object.

     Returns:
       An `Operation`, or None when executing eagerly.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     for callback in self._before_save_callbacks:
       callback()
@@ -253,32 +262,37 @@ class MultiDeviceSaver(object):
       with ops.device(device):
        # _SingleDeviceSaver will use the CPU device when necessary, but initial
        # read operations should be placed on the SaveableObject's device.
-        sharded_saves.append(saver.save(shard_prefix))
+        sharded_saves.append(saver.save(shard_prefix, options))

     with ops.control_dependencies(sharded_saves):
-      # Co-locates the merge step with the last device.
-      with ops.device(saveable_object_util.set_cpu0(last_device)):
+      # Merge on the io_device if specified, otherwise co-locates the merge op
+      # with the last device used.
+      merge_device = (options.experimental_io_device or
+                      saveable_object_util.set_cpu0(last_device))
+      with ops.device(merge_device):
         # V2 format write path consists of a metadata merge step. Once merged,
         # attempts to delete the temporary directory, "<user-fed prefix>_temp".
         return gen_io_ops.merge_v2_checkpoints(
             sharded_prefixes, file_prefix, delete_old_dirs=True)

-  def restore(self, file_prefix):
+  def restore(self, file_prefix, options=None):
     """Restore the saveable objects from a checkpoint with `file_prefix`.

     Args:
       file_prefix: A string or scalar string Tensor containing the prefix for
         files to read from.
+      options: Optional `CheckpointOptions` object.

     Returns:
       A dictionary mapping from SaveableObject names to restore operations.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     restore_ops = {}
     # Sort by device name to avoid propagating non-deterministic dictionary
     # ordering in some Python versions.
     for device, saver in sorted(self._single_device_savers.items()):
       with ops.device(device):
-        restore_ops.update(saver.restore(file_prefix))
+        restore_ops.update(saver.restore(file_prefix, options))

     for callback in self._after_restore_callbacks:
       callback()
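
A condensed sketch of the placement rule the savers above now share (the
helper `resolve_io_device` is illustrative, not part of this change): save and
restore ops land on `options.experimental_io_device` when it is set, else on
`cpu:0`, and the final merge op falls back to the CPU of the last shard's
device.

```python
def resolve_io_device(options, default="cpu:0"):
  # Mirrors _SingleDeviceSaver.save()/restore(): SaveV2 and RestoreV2 run on
  # the configured I/O device when one is given, otherwise on cpu:0.
  return options.experimental_io_device or default
```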

View File

@@ -20,21 +20,37 @@ from __future__ import print_function

 import os

-from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.eager import wrap_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import gfile
+from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import functional_saver
 from tensorflow.python.training.saving import saveable_hook
 from tensorflow.python.training.saving import saveable_object_util

+LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"
+

 class SaverTest(test.TestCase):

+  def setUp(self):
+    super(SaverTest, self).setUp()
+    cpus = config.list_physical_devices("CPU")
+    # Set 3 virtual CPUs
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration()
+    ])
+    self.local_options = checkpoint_options.CheckpointOptions(
+        experimental_io_device=LOCALHOST)
+
   @test_util.run_in_graph_and_eager_modes
   def test_resource_variable(self):
     v1 = resource_variable_ops.ResourceVariable(2.)
@@ -55,6 +71,33 @@ class SaverTest(test.TestCase):
     self.evaluate(second_saver.restore(prefix))
     self.assertEqual(2., self.evaluate(v2))

+  @test_util.run_in_graph_and_eager_modes
+  def test_resource_variable_use_localhost(self):
+    v1 = resource_variable_ops.ResourceVariable(2.)
+    self.evaluate(v1.initializer)
+    saver = functional_saver._SingleDeviceSaver(
+        saveable_object_util.saveable_objects_for_op(v1, "x"))
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
+    self.assertEqual(2, len(gfile.Glob(prefix + "*")))
+    self.evaluate(v1.assign(1.))
+    self.evaluate(saver.restore(prefix, self.local_options))
+    self.assertEqual(2., self.evaluate(v1))
+
+    v2 = resource_variable_ops.ResourceVariable(3.)
+    self.evaluate(v2.initializer)
+    second_saver = functional_saver._SingleDeviceSaver(
+        saveable_object_util.saveable_objects_for_op(v2, "x"))
+    self.evaluate(second_saver.restore(prefix, self.local_options))
+    self.assertEqual(2., self.evaluate(v2))
+
+    # In graph mode, verify that the save and restore ops were set to run on
+    # localhost.
+    if not context.executing_eagerly():
+      for op in ops.get_default_graph().get_operations():
+        if op.type in ("SaveV2", "RestoreV2"):
+          self.assertEqual(LOCALHOST, op.device)
+
   def test_to_proto(self):
     v1 = resource_variable_ops.ResourceVariable(2.)
     saver = functional_saver.MultiDeviceSaver(
@@ -83,12 +126,7 @@ class SaverTest(test.TestCase):
       second_saver.restore(save_path)
     self.assertEqual(2., self.evaluate(v2))

-  @test_util.run_v1_only(
-      "Needs an API to setup multiple devices, b/124805129")
-  # Set up multiple devices when graph building. Before test.main() we configure
-  # the devices for eager execution.
-  @test_util.run_in_graph_and_eager_modes(
-      config=config_pb2.ConfigProto(device_count={"CPU": 3}))
+  @test_util.run_in_graph_and_eager_modes
   def test_checkpoint_is_sharded_by_device(self):
     with ops.device("cpu:0"):
       v0 = resource_variable_ops.ResourceVariable(0.)
@@ -99,9 +137,9 @@ class SaverTest(test.TestCase):
     self.evaluate([v0.initializer, v1.initializer, v2.initializer])
     saver = functional_saver.MultiDeviceSaver(
-        list(saveable_object_util.saveable_objects_for_op(v0, "v0"))
-        + list(saveable_object_util.saveable_objects_for_op(v1, "v1"))
-        + list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
+        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
+        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
+        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     self.evaluate(saver.save(constant_op.constant(prefix)))
     self.assertEqual(4, len(gfile.Glob(prefix + "*")))
@@ -113,8 +151,38 @@ class SaverTest(test.TestCase):
     self.assertEqual(1., self.evaluate(v1))
     self.assertEqual(2., self.evaluate(v2))

+  @test_util.run_in_graph_and_eager_modes
+  def test_checkpoint_multi_device_using_localhost(self):
+    with ops.device("cpu:0"):
+      v0 = resource_variable_ops.ResourceVariable(0.)
+    with ops.device("cpu:1"):
+      v1 = resource_variable_ops.ResourceVariable(1.)
+    with ops.device("cpu:2"):
+      v2 = resource_variable_ops.ResourceVariable(2.)
+
+    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
+    saver = functional_saver.MultiDeviceSaver(
+        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
+        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
+        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
+    self.assertEqual(4, len(gfile.Glob(prefix + "*")))
+    self.evaluate(v0.assign(-1.))
+    self.evaluate(v1.assign(-1.))
+    self.evaluate(v2.assign(-1.))
+    self.evaluate(
+        saver.restore(constant_op.constant(prefix), self.local_options))
+    self.assertEqual(0., self.evaluate(v0))
+    self.assertEqual(1., self.evaluate(v1))
+    self.assertEqual(2., self.evaluate(v2))
+
+    # In graph mode, verify that the save and restore ops were set to run on
+    # localhost.
+    if not context.executing_eagerly():
+      for op in ops.get_default_graph().get_operations():
+        if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
+          self.assertEqual(LOCALHOST, op.device)
+

 class SaveableHookTest(test.TestCase):

   def test_callbacks_run(self):
     # Use dict because an int would be shadowed inside callback.
@@ -144,6 +212,5 @@ class SaveableHookTest(test.TestCase):

 if __name__ == "__main__":
-  ops.enable_eager_execution(
-      config=config_pb2.ConfigProto(device_count={"CPU": 3}))
+  ops.enable_eager_execution()
   test.main()

View File

@@ -150,6 +150,7 @@ py_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/keras:backend",
+        "//tensorflow/python/training/saving:checkpoint_options",
         "//tensorflow/python/training/saving:functional_saver",
         "//tensorflow/python/training/saving:saveable_object_util",
         "@six_archive//:six",
@@ -191,6 +192,7 @@ tf_py_test(
         "//tensorflow/python/keras:engine",
         "//tensorflow/python/keras/layers",
         "//tensorflow/python/keras/optimizer_v2",
+        "//tensorflow/python/training/saving:checkpoint_options",
        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
    ],

View File

@@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training import saver as v1_saver_lib
+from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import functional_saver
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base
@@ -168,7 +169,7 @@ class _CheckpointRestoreCoordinator(object):
   """Holds the status of an object-based checkpoint load."""

   def __init__(self, object_graph_proto, save_path, save_path_tensor,
-               restore_op_cache, graph_view):
+               restore_op_cache, graph_view, options):
     """Specify the checkpoint being loaded.

     Args:
@@ -184,7 +185,9 @@ class _CheckpointRestoreCoordinator(object):
         `restore()` calls.
       graph_view: A graph_view_lib.ObjectGraphView object for the restored
         objects.
+      options: A CheckpointOptions object.
     """
+    self.options = options
     self.object_graph_proto = object_graph_proto
     self.restore_uid = ops.uid()
     # Maps from proto ids to lists of attributes which were in the checkpoint
@@ -291,7 +294,7 @@ class _CheckpointRestoreCoordinator(object):
           ("Saveable keys changed when validating. Got back %s, was "
            "expecting %s") % (tensor_saveables.keys(), validated_names))
     new_restore_ops = functional_saver.MultiDeviceSaver(
-        validated_saveables).restore(self.save_path_tensor)
+        validated_saveables).restore(self.save_path_tensor, self.options)
     if not context.executing_eagerly():
       for name, restore_op in sorted(new_restore_ops.items()):
         restore_ops.append(restore_op)
@@ -1113,13 +1116,15 @@ class TrackableSaver(object):

   def _save_cached_when_graph_building(self,
                                        file_prefix,
-                                       object_graph_tensor=None):
+                                       object_graph_tensor,
+                                       options):
     """Create or retrieve save ops.

     Args:
       file_prefix: The prefix for saved checkpoint files.
       object_graph_tensor: A `Tensor` to which the current object graph will be
         fed.
+      options: `CheckpointOptions` object.

     Returns:
       A two-element tuple with a filename tensor and a feed_dict of tensors to
@@ -1137,14 +1142,15 @@ class TrackableSaver(object):
         # var_list.
         or context.executing_eagerly() or ops.inside_function()):
       saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
-      save_op = saver.save(file_prefix)
+      save_op = saver.save(file_prefix, options=options)
       with ops.device("/cpu:0"):
         with ops.control_dependencies([save_op]):
           self._cached_save_operation = array_ops.identity(file_prefix)
       self._last_save_object_graph = graph_proto
     return self._cached_save_operation, feed_additions

-  def save(self, file_prefix, checkpoint_number=None, session=None):
+  def save(self, file_prefix, checkpoint_number=None, session=None,
+           options=None):
     """Save a training checkpoint.

     The saved checkpoint includes variables created by this object and any
@@ -1162,10 +1168,12 @@ class TrackableSaver(object):
       session: The session to evaluate variables in. Ignored when executing
         eagerly. If not provided when graph building, the default session is
         used.
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       The full path to the checkpoint.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     feed_dict = {}
     use_session = (not context.executing_eagerly() and
                    not ops.inside_function())
@@ -1189,7 +1197,7 @@ class TrackableSaver(object):
       file_io.recursive_create_dir(os.path.dirname(file_prefix))

     save_path, new_feed_additions = self._save_cached_when_graph_building(
-        file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
+        file_prefix_tensor, object_graph_tensor, options)
     if new_feed_additions:
       feed_dict.update(new_feed_additions)
     if not use_session:
@@ -1202,7 +1210,7 @@ class TrackableSaver(object):
     else:
       return save_path

-  def restore(self, save_path):
+  def restore(self, save_path, options=None):
     """Restore a training checkpoint.

     Restores `root_trackable` and any objects that it tracks
@@ -1250,6 +1258,7 @@ class TrackableSaver(object):
         object which may run initializers for objects in the dependency graph.
         If the checkpoint was written by the name-based
         `tf.compat.v1.train.Saver`, names are used to match variables.
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       A load status object, which can be used to make assertions about the
@@ -1260,6 +1269,7 @@ class TrackableSaver(object):
       If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
       object is returned which runs restore ops from a name-based saver.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     if save_path is None:
       return InitializationOnlyStatus(self._graph_view, ops.uid())
     reader = py_checkpoint_reader.NewCheckpointReader(save_path)
@@ -1304,7 +1314,8 @@ class TrackableSaver(object):
         save_path=save_path,
         save_path_tensor=file_prefix_tensor,
         restore_op_cache=self._restore_op_cache,
-        graph_view=self._graph_view)
+        graph_view=self._graph_view,
+        options=options)
     base.CheckpointPosition(
         checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
     load_status = CheckpointLoadStatus(
@@ -1736,6 +1747,8 @@ class Checkpoint(tracking.AutoTrackable):
   checkpoint_directory = "/tmp/training_checkpoints"
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

+  # Create a Checkpoint that will manage two objects with trackable state,
+  # one we name "optimizer" and the other we name "model".
   checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
   status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
   for _ in range(num_training_steps):
@@ -1744,7 +1757,7 @@ class Checkpoint(tracking.AutoTrackable):
     checkpoint.save(file_prefix=checkpoint_prefix)
   ```

-  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
+  `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
   checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
   writes and
   reads `variable.name` based checkpoints. Object-based checkpointing saves a
@@ -1757,7 +1770,7 @@ class Checkpoint(tracking.AutoTrackable):
   arguments to their constructors, and each dependency is given a name that is
   identical to the name of the keyword argument for which it was created.
   TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
-  dependencies on their variables (e.g. "kernel" and "bias" for
+  dependencies on their own variables (e.g. "kernel" and "bias" for
   `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
   dependencies easy in user-defined classes, since `Model` hooks into attribute
   assignment. For example:
@@ -1840,7 +1853,7 @@ class Checkpoint(tracking.AutoTrackable):
             dtype=dtypes.int64,
             trainable=False))

-  def write(self, file_prefix):
+  def write(self, file_prefix, options=None):
     """Writes a training checkpoint.

     The checkpoint includes variables created by this object and any
@@ -1854,14 +1867,35 @@ class Checkpoint(tracking.AutoTrackable):

     Checkpoints written with `write` must be read with `read`.

+    Example usage:
+
+    ```
+    step = tf.Variable(0, name="step")
+    checkpoint = tf.train.Checkpoint(step=step)
+    checkpoint.write("/tmp/ckpt")
+
+    # Later, read the checkpoint with read()
+    checkpoint.read("/tmp/ckpt").assert_consumed()
+
+    # You can also pass options to write() and read(). For example this
+    # runs the IO ops on the localhost:
+    options = tf.train.CheckpointOptions(
+        experimental_io_device="/job:localhost")
+    checkpoint.write("/tmp/ckpt", options=options)
+
+    # Later, read the checkpoint with read()
+    checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
+    ```
+
     Args:
       file_prefix: A prefix to use for the checkpoint filenames
         (/path/to/directory/and_a_prefix).
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       The full path to the checkpoint (i.e. `file_prefix`).
     """
-    output = self._saver.save(file_prefix=file_prefix)
+    options = options or checkpoint_options.CheckpointOptions()
+    output = self._saver.save(file_prefix=file_prefix, options=options)
     if tensor_util.is_tensor(output):
       if context.executing_eagerly():
         return compat.as_str(output.numpy())
@@ -1884,7 +1918,7 @@ class Checkpoint(tracking.AutoTrackable):
       self._maybe_create_save_counter()
     return self._save_counter

-  def save(self, file_prefix):
+  def save(self, file_prefix, options=None):
     """Saves a training checkpoint and provides basic checkpoint management.

     The saved checkpoint includes variables created by this object and any
@@ -1898,14 +1932,33 @@ class Checkpoint(tracking.AutoTrackable):
     provided by other utilities which also wrap `write` and `read`.
     (`tf.train.CheckpointManager` for example).

+    ```
+    step = tf.Variable(0, name="step")
+    checkpoint = tf.train.Checkpoint(step=step)
+    checkpoint.save("/tmp/ckpt")
+
+    # Later, read the checkpoint with restore()
+    checkpoint.restore("/tmp/ckpt-1").assert_consumed()
+
+    # You can also pass options to save() and restore(). For example this
+    # runs the IO ops on the localhost:
+    options = tf.train.CheckpointOptions(
+        experimental_io_device="/job:localhost")
+    checkpoint.save("/tmp/ckpt", options=options)
+
+    # Later, read the checkpoint with restore()
+    checkpoint.restore("/tmp/ckpt-2", options=options).assert_consumed()
+    ```
+
     Args:
       file_prefix: A prefix to use for the checkpoint filenames
         (/path/to/directory/and_a_prefix). Names are generated based on this
         prefix and `Checkpoint.save_counter`.
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       The full path to the checkpoint.
     """
+    options = options or checkpoint_options.CheckpointOptions()
     graph_building = not context.executing_eagerly()
     if graph_building:
       if ops.inside_function():
@@ -1931,7 +1984,8 @@ class Checkpoint(tracking.AutoTrackable):
       checkpoint_number = session.run(self._save_assign_op)
     else:
       checkpoint_number = assign_op.numpy()
-    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
+    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
+                           options=options)
     checkpoint_management.update_checkpoint_state_internal(
         save_dir=os.path.dirname(file_prefix),
         model_checkpoint_path=file_path,
@@ -1939,7 +1993,7 @@ class Checkpoint(tracking.AutoTrackable):
         save_relative_paths=True)
     return file_path

-  def read(self, save_path):
+  def read(self, save_path, options=None):
     """Read a training checkpoint written with `write`.

     Reads this `Checkpoint` and any objects it depends on.
@@ -1962,18 +2016,25 @@ class Checkpoint(tracking.AutoTrackable):
     # Later, load the checkpoint with read()
     # With restore() assert_consumed() would have failed.
     checkpoint.read(path).assert_consumed()
+
+    # You can also pass options to read(). For example this
+    # runs the IO ops on the localhost:
+    options = tf.train.CheckpointOptions(
+        experimental_io_device="/job:localhost")
+    checkpoint.read(path, options=options)
     ```

     Args:
       save_path: The path to the checkpoint as returned by `write`.
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       A load status object, which can be used to make assertions about the
       status of a checkpoint restoration. See `restore` for details.
     """
-    return self._saver.restore(save_path=save_path)
+    options = options or checkpoint_options.CheckpointOptions()
+    return self._saver.restore(save_path=save_path, options=options)

-  def restore(self, save_path):
+  def restore(self, save_path, options=None):
     """Restore a training checkpoint.

     Restores this `Checkpoint` and any objects it depends on.
@@ -1995,6 +2056,10 @@ class Checkpoint(tracking.AutoTrackable):
     ```python
     checkpoint = tf.train.Checkpoint( ... )
     checkpoint.restore(path).assert_consumed()
+
+    # You can additionally pass options to restore():
+    options = tf.train.CheckpointOptions(
+        experimental_io_device="/job:localhost")
+    checkpoint.restore(path, options=options).assert_consumed()
     ```

     An exception will be raised if any Python objects in the dependency graph
@@ -2011,6 +2076,7 @@ class Checkpoint(tracking.AutoTrackable):
         `tf.train.latest_checkpoint`. If the checkpoint was written by the
         name-based `tf.compat.v1.train.Saver`, names are used to match
         variables.
+      options: Optional `tf.train.CheckpointOptions` object.

     Returns:
       A load status object, which can be used to make assertions about the
@@ -2049,7 +2115,7 @@ class Checkpoint(tracking.AutoTrackable):
         checkpoint file or object when the `Checkpoint` object is deleted
         (often at program shutdown).
     """
-    status = self.read(save_path)
+    status = self.read(save_path, options=options)
     # Create the save counter now so it gets initialized with other variables
     # when graph building. Creating it earlier would lead to errors when using,
     # say, train.Saver() to save the model before initializing it.

View File

@@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
+from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.tracking import base
 from tensorflow.python.training.tracking import graph_view
 from tensorflow.python.training.tracking import tracking
@@ -409,6 +410,28 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     del ckpt
     status.assert_consumed()

+  @test_util.run_in_graph_and_eager_modes
+  def testPassingCheckpointOptions(self):
+    localhost = "/job:localhost/device:CPU:0"
+    options = checkpoint_options.CheckpointOptions(
+        experimental_io_device=localhost)
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    v = variable_scope.get_variable(name="v", initializer=0.)
+    self.evaluate(v.initializer)
+    ckpt = trackable_utils.Checkpoint(v=v)
+    self.evaluate(trackable_utils.gather_initializers(ckpt))
+    save_path = ckpt.save(file_prefix=prefix, options=options)
+    status = ckpt.restore(save_path=save_path, options=options)
+    del ckpt
+    status.assert_consumed()
+
+    # In graph mode, verify that the save and restore ops were set to run on
+    # localhost.
+    if not context.executing_eagerly():
+      for op in ops.get_default_graph().get_operations():
+        if op.type in ("SaveV2", "RestoreV2"):
+          self.assertEqual(localhost, op.device)
+
   @test_util.run_in_graph_and_eager_modes
   def testSaveRestore(self):
     model = MyModel()

View File

@@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
 tf_class {
   is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_io_device"
+    mtype: "<type \'member_descriptor\'>"
+  }
   member {
     name: "function_aliases"
     mtype: "<type \'member_descriptor\'>"
@@ -16,6 +20,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
   }
 }

View File

@@ -0,0 +1,13 @@
+path: "tensorflow.train.CheckpointOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_io_device"
+    mtype: "<type \'member_descriptor\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}

View File

@@ -28,6 +28,10 @@ tf_module {
     name: "CheckpointManager"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "CheckpointOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "CheckpointSaverHook"
     mtype: "<type \'type\'>"

View File

@@ -2,6 +2,10 @@ path: "tensorflow.saved_model.SaveOptions"
 tf_class {
   is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_io_device"
+    mtype: "<type \'member_descriptor\'>"
+  }
   member {
     name: "function_aliases"
     mtype: "<type \'member_descriptor\'>"
@@ -16,6 +20,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
   }
 }

View File

@@ -0,0 +1,13 @@
+path: "tensorflow.train.CheckpointOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.saving.checkpoint_options.CheckpointOptions\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_io_device"
+    mtype: "<type \'member_descriptor\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}

View File

@@ -14,18 +14,18 @@ tf_class {
   }
   member_method {
     name: "read"
-    argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "restore"
-    argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'save_path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "write"
-    argspec: "args=[\'self\', \'file_prefix\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'file_prefix\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
 }

View File

@@ -12,6 +12,10 @@ tf_module {
     name: "CheckpointManager"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "CheckpointOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "ClusterDef"
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"