Add a functional saver, use it for object-based checkpointing

Pulls some utilities out of saver.py that are necessary to actually use the functional saver. The functional saver takes only SaveableObjects, so these are utilities for taking whatever users pass in and converting it to SaveableObjects.

One other code move for object-based checkpointing to avoid circular imports.

Applications that need a SaverDef still use the old Saver; serialization to SaverDef will be added to the new saver in a follow-up.

Does not yet wrap the new Saver's methods in @tf.function, since there are memory issues that need to be fixed first.
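
A minimal sketch of the intended flow (mirroring the new functional_saver_test; assumes eager execution is enabled, e.g. via tf.enable_eager_execution(), and the checkpoint prefix is illustrative):

import os

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util

# Convert whatever the user has (here a ResourceVariable) into SaveableObjects.
v = resource_variable_ops.ResourceVariable(2.)
saveables = list(saveable_object_util.saveable_objects_for_op(v, "v"))

# The functional saver accepts only SaveableObjects.
saver = functional_saver.Saver(saveables)
prefix = os.path.join("/tmp", "ckpt")  # illustrative prefix
save_path = saver.save(constant_op.constant(prefix))

v.assign(1.)
saver.restore(save_path)  # v is 2. again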

PiperOrigin-RevId: 224561069
Author: Allen Lavoie, 2018-12-07 12:40:49 -08:00 (committed by TensorFlower Gardener)
parent 4f543e588a
commit 66ca3cd10d
14 changed files with 687 additions and 416 deletions

tensorflow/contrib/opt/python/training/elastic_average_optimizer.py

@@ -30,6 +30,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.saving import saveable_object_util
LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'
@@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
if var_list is None:
var_list = variables.trainable_variables()
if not isinstance(var_list, dict):
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
var_list = saveable_object_util.op_list_to_dict(var_list)
swapped_var_list = {}
for key, var in var_list.items():
@@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
def after_create_session(self, session, coord):
"""Run initialization ops"""
session.run(self._variable_init_op)
session.run(self._variable_init_op)
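
For reference, op_list_to_dict is a direct move of BaseSaverBuilder.OpListToDict; a minimal sketch of its behavior (graph mode assumed, variable names illustrative):

from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object_util

a = variables.Variable(1.0, name="a")
b = variables.Variable(2.0, name="b")
# Maps each variable's op name to the tensor to save under that name.
names_to_saveables = saveable_object_util.op_list_to_dict([a, b])
assert set(names_to_saveables.keys()) == {"a", "b"}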

tensorflow/contrib/opt/python/training/moving_average_optimizer.py

@@ -26,6 +26,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
from tensorflow.python.training.saving import saveable_object_util
class MovingAverageOptimizer(optimizer.Optimizer):
@@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer):
if var_list is None:
var_list = variables.global_variables()
if not isinstance(var_list, dict):
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
var_list = saveable_object_util.op_list_to_dict(var_list)
v_name_to_tensor = {}
for k, tensor_or_list in six.iteritems(var_list):

tensorflow/python/BUILD

@@ -3515,13 +3515,13 @@ py_library(
exclude = [
"**/*test*",
"training/checkpointable/**/*.py",
"training/saving/**/*.py",
# The following targets have their own build rules (same name as the
# file):
"training/basic_session_run_hooks.py",
"training/checkpoint_management.py",
"training/distribute.py",
"training/distribution_strategy_context.py",
"training/saveable_object.py",
"training/saver.py",
"training/session_run_hook.py",
"training/training_util.py",
@@ -3596,12 +3596,6 @@ py_library(
],
)
py_library(
name = "saveable_object",
srcs = ["training/saveable_object.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "checkpoint_management",
srcs = ["training/checkpoint_management.py"],
@@ -3655,7 +3649,6 @@ py_library(
":platform",
":pywrap_tensorflow",
":resource_variable_ops",
":saveable_object",
":session",
":state_ops",
":string_ops",
@@ -3665,6 +3658,8 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//third_party/py/numpy",
"@six_archive//:six",
],

tensorflow/python/training/checkpoint_utils.py

@@ -30,7 +30,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export
@@ -311,10 +311,10 @@ def _set_checkpoint_initializer(variable,
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
names_to_saveables = saveable_object_util.op_list_to_dict([variable])
saveable_objects = []
for name, op in names_to_saveables.items():
for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
for s in saveable_object_util.saveable_objects_for_op(op, name):
saveable_objects.append(s)
assert len(saveable_objects) == 1 # Should be only one variable.
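
The replacement pair behaves exactly like the old BaseSaverBuilder static methods; a sketch in isolation (graph mode assumed, variable name illustrative):

from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object_util

var = variables.Variable(1.0, name="v")
names_to_saveables = saveable_object_util.op_list_to_dict([var])
saveable_objects = []
for name, op in names_to_saveables.items():
  # Yields the SaveableObject(s) which together save/restore `op`.
  for s in saveable_object_util.saveable_objects_for_op(op, name):
    saveable_objects.append(s)
assert len(saveable_objects) == 1  # one unsliced variable -> one saveable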

tensorflow/python/training/checkpointable/BUILD

@@ -25,9 +25,9 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops_gen",
"//tensorflow/python:platform",
"//tensorflow/python:saveable_object",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/saving:saveable_object",
],
)
@@ -114,7 +114,6 @@ py_library(
"//tensorflow/python:init_ops",
"//tensorflow/python:io_ops_gen",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:saveable_object",
"//tensorflow/python:saver",
"//tensorflow/python:session",
"//tensorflow/python:tensor_shape",
@@ -123,6 +122,9 @@ py_library(
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
],
)

tensorflow/python/training/checkpointable/base.py

@@ -25,7 +25,6 @@ import weakref
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -34,7 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saveable_object
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
@@ -374,41 +373,10 @@ class _CheckpointPosition(object):
eagerly.
"""
(restore_ops,
named_saveables,
tensor_saveables,
python_saveables) = self._gather_ops_or_named_saveables()
# Eagerly run restorations for Python state.
reader = pywrap_tensorflow.NewCheckpointReader(
self._checkpoint.save_path_string)
for saveable in python_saveables:
spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore(
[reader.get_tensor(name) for name in spec_names])
# If we have new SaveableObjects, extract and cache restore ops.
if named_saveables:
validated_saveables = (
self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access
validated_names = set(saveable.name for saveable in validated_saveables)
if set(named_saveables.keys()) != validated_names:
raise AssertionError(
("Saveable keys changed when validating. Got back %s, was "
"expecting %s") % (named_saveables.keys(), validated_names))
all_tensors = self._checkpoint.builder.bulk_restore(
filename_tensor=self._checkpoint.save_path_tensor,
saveables=validated_saveables, preferred_shard=-1,
restore_sequentially=False)
saveable_index = 0
for saveable in validated_saveables:
num_specs = len(saveable.specs)
saveable_tensors = all_tensors[
saveable_index:saveable_index + num_specs]
saveable_index += num_specs
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
if not context.executing_eagerly():
assert saveable.name not in self._checkpoint.restore_ops_by_name
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
restore_ops.append(restore_op)
restore_ops.extend(self._checkpoint.restore_saveables(
tensor_saveables, python_saveables))
return restore_ops
@property

tensorflow/python/training/checkpointable/util.py

@@ -40,11 +40,14 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -89,7 +92,6 @@ class _CheckpointRestoreCoordinator(object):
referenced every restore (e.g. for Python state); otherwise they would
create their own ops every restore.
"""
self.builder = saver_lib.BulkSaverBuilder()
self.object_graph_proto = object_graph_proto
self.restore_uid = ops.uid()
# Maps from objects to lists of attributes which were in the checkpoint but
@@ -144,6 +146,57 @@ class _CheckpointRestoreCoordinator(object):
if self.new_restore_ops_callback:
self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
def restore_saveables(self, tensor_saveables, python_saveables):
"""Run or build restore operations for SaveableObjects.
Args:
tensor_saveables: `SaveableObject`s which correspond to Tensors.
python_saveables: `PythonStateSaveable`s which correspond to Python
values.
Returns:
When graph building, a list of restore operations, either cached or newly
created, to restore `tensor_saveables`.
"""
restore_ops = []
# Eagerly run restorations for Python state.
reader = pywrap_tensorflow.NewCheckpointReader(
self.save_path_string)
for saveable in python_saveables:
spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore(
[reader.get_tensor(name) for name in spec_names])
# If we have new SaveableObjects, extract and cache restore ops.
if tensor_saveables:
validated_saveables = saveable_object_util.validate_and_slice_inputs(
tensor_saveables)
validated_names = set(saveable.name for saveable in validated_saveables)
if set(tensor_saveables.keys()) != validated_names:
raise AssertionError(
("Saveable keys changed when validating. Got back %s, was "
"expecting %s") % (tensor_saveables.keys(), validated_names))
for saveable in validated_saveables:
if saveable.device:
device = saveable_object_util.set_cpu0(saveable.device)
else:
device = None
with ops.device(device):
tensors = []
for spec in saveable.specs:
tensors.append(
io_ops.restore_v2(
self.save_path_tensor,
[spec.name],
[spec.slice_spec],
[spec.dtype])[0])
restore_op = saveable.restore(tensors, restored_shapes=None)
if not context.executing_eagerly():
assert saveable.name not in self.restore_ops_by_name
self.restore_ops_by_name[saveable.name] = restore_op
restore_ops.append(restore_op)
return restore_ops
class _NameBasedRestoreCoordinator(object):
"""Keeps the status of a name-based checkpoint restore."""
@@ -183,11 +236,11 @@ class _NameBasedRestoreCoordinator(object):
continue
else:
saveable = saveable_factory
names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict(
names_to_saveables = saveable_object_util.op_list_to_dict(
[saveable],
convert_variable_to_tensor=False)
for name, op in names_to_saveables.items():
for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
for saveable_object in saveable_object_util.saveable_objects_for_op(
op=op, name=name):
yield saveable_object
@@ -606,10 +659,10 @@ def _add_attributes_to_object_graph(
# Figure out the name-based Saver's name for this variable. If it's
# already a SaveableObject we'd just get the checkpoint key back, so
# we leave full_name blank.
saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
saver_dict = saveable_object_util.op_list_to_dict(
[maybe_saveable], convert_variable_to_tensor=False)
full_name, = saver_dict.keys()
saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
saveables = tuple(saveable_object_util.saveable_objects_for_op(
op=maybe_saveable, name=attribute.checkpoint_key))
for saveable in saveables:
saveable.full_name = full_name
@@ -1226,7 +1279,7 @@ class NameBasedSaverStatus(_LoadStatus):
session = ops.get_default_session()
with ops.device("/cpu:0"):
saveables = self._gather_saveable_objects()
saver_lib.Saver(saveables).restore(
v1_saver_lib.Saver(saveables).restore(
sess=session, save_path=self._checkpoint.save_path)
def initialize_or_restore(self, session=None):
@@ -1251,18 +1304,6 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
fetches=fetches, feed_dict=feed_dict, **kwargs)
def _copy_saver_with_new_var_list(old_saver, new_var_list):
"""Copy a `tf.train.Saver`'s state to a new Saver with different variables."""
new_saver = saver_lib.Saver(var_list=new_var_list, max_to_keep=None)
# TODO(allenl): Move to copying functionality to Saver?
# pylint: disable=protected-access
new_saver._last_checkpoints = old_saver._last_checkpoints
new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted
new_saver._next_checkpoint_time = old_saver._next_checkpoint_time
# pylint: enable=protected-access
return new_saver
class CheckpointableSaver(object):
"""Saves and restores a `Checkpointable` object and its dependencies.
@@ -1301,7 +1342,8 @@ class CheckpointableSaver(object):
# Op caching for save
self._object_graph_feed_tensor = None
self._last_save_object_graph = None
self._last_save_saver = None
self._file_prefix_feed_tensor = None
self._cached_save_operation = None
# Op caching for restore, shared between _CheckpointRestoreCoordinators
self._restore_op_cache = {}
@@ -1368,13 +1410,16 @@ class CheckpointableSaver(object):
base.NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
# TODO(allenl, haoliang): Swap in a function-based saver here.
return saver_lib.Saver(
# TODO(allenl): Swap in a function-based saver here once it can serialize
# to a SaverDef.
return v1_saver_lib.Saver(
var_list=named_saveable_objects, max_to_keep=None)
def _prepare_save(self,
object_graph_tensor=None,
saveable_object_cache=None):
def _save_cached_when_graph_building(
self,
file_prefix,
object_graph_tensor=None,
saveable_object_cache=None):
"""Create or retrieve save ops.
When graph building, `saveable_object_cache` will typically be non-`None`,
@@ -1383,15 +1428,17 @@ class CheckpointableSaver(object):
unnecessarily re-creating save ops.
Args:
file_prefix: The prefix for saved checkpoint files.
object_graph_tensor: A `Tensor` to which the current object graph will be
fed.
saveable_object_cache: A dictionary; if specified, used to cache
`SaveableObject`s.
Returns:
A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
to feed when running save ops. The feed dict contains the current object
graph and any Python state to be saved in the checkpoint.
A two-element tuple with a filename tensor and a feed_dict of tensors to
feed when running it (if graph building). The feed dict contains the
current object graph and any Python state to be saved in the
checkpoint. When executing eagerly only the first argument is meaningful.
"""
(named_saveable_objects, graph_proto,
feed_additions) = self._gather_saveables(
@@ -1403,15 +1450,11 @@ class CheckpointableSaver(object):
# constructors. That means the Saver needs to be copied with a new
# var_list.
or context.executing_eagerly()):
if self._last_save_object_graph is not None:
self._last_save_saver = _copy_saver_with_new_var_list(
old_saver=self._last_save_saver,
new_var_list=named_saveable_objects)
else:
self._last_save_saver = saver_lib.Saver(
var_list=named_saveable_objects, max_to_keep=None)
saver = functional_saver.Saver(named_saveable_objects)
with ops.device("/cpu:0"):
self._cached_save_operation = saver.save(file_prefix)
self._last_save_object_graph = graph_proto
return self._last_save_saver, feed_additions
return self._cached_save_operation, feed_additions
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1435,36 +1478,42 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
feed_additions = {}
feed_dict = {}
graph_building = not context.executing_eagerly()
if checkpoint_number:
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
self._file_prefix_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
file_prefix_tensor = self._file_prefix_feed_tensor
feed_dict[file_prefix_tensor] = file_prefix
else:
with ops.device("/cpu:0"):
file_prefix_tensor = constant_op.constant(
file_prefix, dtype=dtypes.string)
object_graph_tensor = None
saver, new_feed_additions = self._prepare_save(
file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor,
object_graph_tensor=object_graph_tensor,
saveable_object_cache=self._saveable_object_cache)
if new_feed_additions:
feed_additions.update(new_feed_additions)
feed_dict.update(new_feed_additions)
if not graph_building:
session = None
elif session is None:
session = ops.get_default_session()
file_io.recursive_create_dir(os.path.dirname(file_prefix))
with ops.device("/cpu:0"):
save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
write_meta_graph=False,
write_state=False,
global_step=checkpoint_number)
if session:
save_path = session.run(save_path, feed_dict=feed_dict)
else:
save_path = save_path.numpy()
return save_path
def restore(self, save_path):
@@ -1753,9 +1802,9 @@ class Checkpoint(tracking.Checkpointable):
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
return self._saver.save(
return compat.as_str(self._saver.save(
file_prefix=file_prefix,
session=session)
session=session))
@property
def save_counter(self):
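
End-user behavior is unchanged apart from save() now returning str; a minimal sketch, assuming eager execution is enabled and with an illustrative prefix:

from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import util

v = resource_variable_ops.ResourceVariable(3.)
ckpt = util.Checkpoint(v=v)
path = ckpt.save("/tmp/tf_ckpt/ckpt")  # e.g. ".../ckpt-1", now a str
v.assign(0.)
ckpt.restore(path)  # v is 3. again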

tensorflow/python/training/saver.py

@@ -14,7 +14,11 @@
# ==============================================================================
# pylint: disable=invalid-name
"""Save and restore variables."""
"""Save and restore variables.
Symbols in this file are deprecated. See replacements in
tensorflow/python/training/checkpointable and tensorflow/python/training/saving.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -25,7 +29,6 @@ import time
import uuid
import numpy as np
import six
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@@ -42,16 +45,15 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saveable_object
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -67,31 +69,6 @@ get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint
# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
"VariableV2",
"AutoReloadVariable",
"VarHandleOp",
"ReadVariableOp"])
def _set_cpu0(device_string):
"""Creates a new device string based on `device_string` but using /CPU:0.
If the device is already on /CPU:0, this is a no-op.
Args:
device_string: A device string.
Returns:
A device string.
"""
parsed_device = pydev.DeviceSpec.from_string(device_string)
parsed_device.device_type = "CPU"
parsed_device.device_index = 0
return parsed_device.to_string()
class BaseSaverBuilder(object):
"""Base class for Savers.
@@ -101,64 +78,9 @@ class BaseSaverBuilder(object):
SaveSpec = saveable_object.SaveSpec
SaveableObject = saveable_object.SaveableObject
class VariableSaveable(SaveableObject):
"""SaveableObject implementation that handles Variables."""
def __init__(self, var, slice_spec, name):
spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype)
super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
restored_tensor = restored_tensors[0]
if restored_shapes is not None:
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
return state_ops.assign(
self.op,
restored_tensor,
validate_shape=restored_shapes is None and
self.op.get_shape().is_fully_defined())
class ResourceVariableSaveable(SaveableObject):
"""SaveableObject implementation that handles ResourceVariables."""
def __init__(self, var, slice_spec, name):
self._var_device = var.device
self._var_shape = var.shape
if isinstance(var, ops.Tensor):
self.handle_op = var.op.inputs[0]
tensor = var
elif isinstance(var, resource_variable_ops.ResourceVariable):
def _read_variable_closure(v):
def f():
with ops.device(v.device):
x = v.read_value()
# To allow variables placed on non-CPU devices to be checkpointed,
# we copy them to CPU on the same machine first.
with ops.device("/device:CPU:0"):
return array_ops.identity(x)
return f
self.handle_op = var.handle
tensor = _read_variable_closure(var)
else:
raise ValueError(
"Saveable is neither a resource variable nor a read operation."
" Got: %s" % repr(var))
spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name,
dtype=var.dtype)
super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
restored_tensor = restored_tensors[0]
if restored_shapes is not None:
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
# Copy the restored tensor to the variable's device.
with ops.device(self._var_device):
restored_tensor = array_ops.identity(restored_tensor)
return resource_variable_ops.shape_safe_assign_variable_handle(
self.handle_op, self._var_shape, restored_tensor)
# Aliases for code which was moved but still has lots of users.
VariableSaveable = saveable_object_util.ReferenceVariableSaveable
ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
def __init__(self, write_version=saver_pb2.SaverDef.V2):
self._write_version = write_version
@@ -224,7 +146,11 @@ class BaseSaverBuilder(object):
del restore_sequentially
all_tensors = []
for saveable in saveables:
with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
if saveable.device:
device = saveable_object_util.set_cpu0(saveable.device)
else:
device = None
with ops.device(device):
all_tensors.extend(
self.restore_op(filename_tensor, saveable, preferred_shard))
return all_tensors
@@ -336,7 +262,7 @@ class BaseSaverBuilder(object):
last_device = None
for shard, (device, saveables) in enumerate(per_device):
last_device = device
with ops.device(_set_cpu0(device)):
with ops.device(saveable_object_util.set_cpu0(device)):
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
num_shards_tensor)
sharded_prefixes.append(sharded_filename)
@@ -344,7 +270,7 @@ class BaseSaverBuilder(object):
with ops.control_dependencies([x.op for x in sharded_saves]):
# Co-locates the merge step with the last device.
with ops.device(_set_cpu0(last_device)):
with ops.device(saveable_object_util.set_cpu0(last_device)):
# V2 format write path consists of a metadata merge step. Once merged,
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
merge_step = gen_io_ops.merge_v2_checkpoints(
@@ -459,10 +385,6 @@ class BaseSaverBuilder(object):
name="restore_shard"))
return control_flow_ops.group(*sharded_restores, name="restore_all")
@staticmethod
def _IsVariable(v):
return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
def _GroupByDevices(self, saveables):
"""Group Variable tensor slices per device.
@@ -490,220 +412,6 @@ class BaseSaverBuilder(object):
per_device[canonical_device.pop()].append(saveable)
return sorted(per_device.items(), key=lambda t: t[0])
@staticmethod
def OpListToDict(op_list, convert_variable_to_tensor=True):
"""Create a dictionary of names to operation lists.
Args:
op_list: A list, tuple, or set of Variables or SaveableObjects.
convert_variable_to_tensor: Whether or not to convert single Variables
with no slice info into Tensors.
Returns:
A dictionary of names to the operations that must be saved under
that name. Variables with save_slice_info are grouped together under the
same key in no particular order.
Raises:
TypeError: If the type of op_list or its elements is not supported.
ValueError: If at least two saveables share the same name.
"""
if not isinstance(op_list, (list, tuple, set)):
raise TypeError("Variables to save should be passed in a dict or a "
"list: %s" % op_list)
# When ResourceVariables are converted to Tensors, read ops are added to the
# graph. Sorting the op_list ensures that the resulting graph is always
# constructed in a deterministic way:
op_list = sorted(op_list, key=lambda x: x.name)
names_to_saveables = {}
# pylint: disable=protected-access
for var in op_list:
if isinstance(var, BaseSaverBuilder.SaveableObject):
names_to_saveables[var.name] = var
elif isinstance(var, variables.PartitionedVariable):
if var.name in names_to_saveables:
raise ValueError("At least two variables have the same name: %s" %
var.name)
names_to_saveables[var.name] = var
elif isinstance(var, variables.Variable) and var._save_slice_info:
name = var._save_slice_info.full_name
if name in names_to_saveables:
if not isinstance(names_to_saveables[name], list):
raise ValueError("Mixing slices and non-slices with the same name: "
"%s" % name)
names_to_saveables[name].append(var)
else:
names_to_saveables[name] = [var]
elif (isinstance(var, checkpointable.CheckpointableBase)
and not isinstance(var, variables.Variable)):
checkpointable_saveables = [
(factory() if callable(factory) else factory)
for factory in var._gather_saveables_for_checkpoint().values()]
names_to_saveables.update(
BaseSaverBuilder.OpListToDict(checkpointable_saveables))
else:
if context.executing_eagerly():
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"Can only save/restore ResourceVariables when eager execution "
"is enabled, type: %s." % type(var))
set_var = names_to_saveables.setdefault(var._shared_name, var)
if set_var is not var:
raise ValueError(
("Two different ResourceVariable objects with the same "
"shared_name '%s' were passed to the Saver. This likely means "
"that they were created in different Graphs or isolation "
"contexts, and may not be checkpointed together.") %
(var._shared_name,))
else:
if convert_variable_to_tensor:
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var._graph_element # pylint: disable=protected-access
else:
var = ops.internal_convert_to_tensor(var, as_ref=True)
if not BaseSaverBuilder._IsVariable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
if var.op.type == "ReadVariableOp":
name = var.op.inputs[0].op.name
else:
name = var.op.name
if name in names_to_saveables:
raise ValueError("At least two variables have the same name: %s" %
name)
names_to_saveables[name] = var
# pylint: enable=protected-access
return names_to_saveables
@staticmethod
def SaveableObjectsForOp(op, name):
"""Create `SaveableObject`s from an operation.
Args:
op: A variable, operation, or SaveableObject to coerce into a
SaveableObject.
name: A string name for the SaveableObject.
Yields:
`SaveableObject`s which together save/restore `op`.
Raises:
TypeError: If `name` is not a string.
ValueError: For operations with no known conversion to SaveableObject.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"names_to_saveables must be a dict mapping string names to "
"checkpointable operations. Name is not a string: %s" % name)
if isinstance(op, BaseSaverBuilder.SaveableObject):
yield op
elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
if isinstance(op, variables.PartitionedVariable):
op = list(op)
# A set of slices.
slice_name = None
# pylint: disable=protected-access
for variable in op:
if not isinstance(variable, variables.Variable):
raise ValueError("Slices must all be Variables: %s" % variable)
if not variable._save_slice_info:
raise ValueError("Slices must all be slices: %s" % variable)
if slice_name is None:
slice_name = variable._save_slice_info.full_name
elif slice_name != variable._save_slice_info.full_name:
raise ValueError(
"Slices must all be from the same tensor: %s != %s" %
(slice_name, variable._save_slice_info.full_name))
if variable.op.type in ["Variable", "VariableV2",
"AutoReloadVariable"]:
yield BaseSaverBuilder.VariableSaveable(
variable, variable._save_slice_info.spec, name)
else:
yield BaseSaverBuilder.ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
op, variables.Variable):
# pylint: disable=protected-access
for attr, factory in op._gather_saveables_for_checkpoint().items():
if attr == checkpointable.VARIABLE_VALUE_KEY:
# Keep original name for classes masquerading as variables.
full_name = name
else:
full_name = name + "_" + attr
op = (factory(full_name) if callable(factory) else factory)
for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name):
yield op
# pylint: enable=protected-access
else:
# A variable or tensor.
if context.executing_eagerly():
if not isinstance(op, resource_variable_ops.ResourceVariable):
raise ValueError("Can only save/restore ResourceVariable eager "
"mode is enabled, type: %s." % type(op))
yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
else:
if isinstance(op, resource_variable_ops.ResourceVariable):
variable = op._graph_element # pylint: disable=protected-access
else:
variable = ops.internal_convert_to_tensor(op, as_ref=True)
if not BaseSaverBuilder._IsVariable(variable):
raise TypeError("names_to_saveables must be a dict mapping string "
"names to Tensors/Variables. Not a variable: %s" %
variable)
if variable.op.type in ["Variable", "VariableV2",
"AutoReloadVariable"]:
yield BaseSaverBuilder.VariableSaveable(variable, "", name)
else:
yield BaseSaverBuilder.ResourceVariableSaveable(
variable, "", name)
def _ValidateAndSliceInputs(self, names_to_saveables):
"""Returns the variables and names that will be used for a Saver.
Args:
names_to_saveables: A dict (k, v) where k is the name of an operation and
v is an operation to save or a BaseSaverBuilder.Saver.
Returns:
A list of BaseSaverBuilder.SaveableObject objects.
Raises:
TypeError: If any of the keys are not strings or any of the
values are not one of Tensor or Variable or a checkpointable operation.
ValueError: If the same operation is given in more than one value
(this also applies to slices of SlicedVariables).
"""
if not isinstance(names_to_saveables, dict):
names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
saveables = []
seen_ops = set()
for name, op in sorted(names_to_saveables.items(),
# Avoid comparing ops, sort only by name.
key=lambda x: x[0]):
for converted_saveable_object in self.SaveableObjectsForOp(op, name):
self._AddSaveable(saveables, seen_ops, converted_saveable_object)
return saveables
def _AddSaveable(self, saveables, seen_ops, saveable):
"""Adds the saveable to the saveables list.
Args:
saveables: List to append the SaveableObject to.
seen_ops: Set of the ops of the saveables already processed. Used to
check that each saveable is only saved once.
saveable: The saveable.
Raises:
ValueError: If the saveable has already been processed.
"""
if saveable.op in seen_ops:
raise ValueError("The same saveable will be restored with two names: %s" %
saveable.name)
saveables.append(saveable)
seen_ops.add(saveable.op)
def build(self,
names_to_saveables,
reshape=False,
@@ -775,7 +483,8 @@ class BaseSaverBuilder(object):
raise ValueError("save and restore operations need to be built together "
" when eager execution is not enabled.")
saveables = self._ValidateAndSliceInputs(names_to_saveables)
saveables = saveable_object_util.validate_and_slice_inputs(
names_to_saveables)
if max_to_keep is None:
max_to_keep = 0
@@ -1910,7 +1619,7 @@ def saver_from_object_based_checkpoint(
if builder is None:
builder = BulkSaverBuilder()
saveables = builder._ValidateAndSliceInputs(var_list) # pylint: disable=protected-access
saveables = saveable_object_util.validate_and_slice_inputs(var_list)
current_names = set()
for saveable in saveables:
for spec in saveable.specs:
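
set_cpu0 (previously the private _set_cpu0 above) rewrites only the device type and index, preserving job/replica/task; for example:

from tensorflow.python.training.saving import saveable_object_util

assert saveable_object_util.set_cpu0(
    "/job:worker/replica:0/task:1/device:GPU:0"
) == "/job:worker/replica:0/task:1/device:CPU:0"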

tensorflow/python/training/saving/BUILD

@@ -0,0 +1,55 @@
# Description:
# Low-level utilities for reading and writing checkpoints.
package(
default_visibility = [
"//tensorflow:internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "functional_saver",
srcs = ["functional_saver.py"],
srcs_version = "PY2AND3",
deps = [
":saveable_object",
":saveable_object_util",
"//tensorflow/python/eager:def_function",
],
)
cuda_py_test(
name = "functional_saver_test",
size = "medium",
srcs = [
"functional_saver_test.py",
],
additional_deps = [
":functional_saver",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "saveable_object",
srcs = ["saveable_object.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "saveable_object_util",
srcs = ["saveable_object_util.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
],
)

tensorflow/python/training/saving/functional_saver.py

@@ -0,0 +1,101 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restore variables inside traced @tf.functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
class Saver(object):
"""A minimal utility class for saving and restoring checkpoints.
Note that this is a low-level utility which stores Tensors in the keys
specified by `SaveableObject`s. Higher-level utilities for object-based
checkpointing are built on top of it.
"""
def __init__(self, saveable_objects):
"""Specify a list of `SaveableObject`s to save and restore.
Args:
saveable_objects: A list of `SaveableObject`s.
"""
saveable_objects = list(saveable_objects)
for saveable in saveable_objects:
if not isinstance(saveable, saveable_object.SaveableObject):
raise ValueError(
"Saver expected a list of SaveableObjects, got %s." % (saveable,))
self._saveable_objects = saveable_objects
# TODO(b/120569892): Use tf.function here
def save(self, file_prefix):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
Returns:
A scalar string Tensor containing `file_prefix` with control dependencies
on the save ops.
"""
tensor_names = []
tensors = []
tensor_slices = []
for saveable in self._saveable_objects:
for spec in saveable.specs:
tensor_names.append(spec.name)
tensors.append(spec.tensor)
tensor_slices.append(spec.slice_spec)
with ops.control_dependencies(
[io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)]):
return array_ops.identity(file_prefix)
# TODO(b/120569892): Use tf.function here
def restore(self, file_prefix):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
Returns:
An operation which restores the `Saver`'s `SaveableObject`s when run, or
None if executing eagerly.
"""
restore_ops = []
for saveable in self._saveable_objects:
if saveable.device:
device = saveable_object_util.set_cpu0(saveable.device)
else:
device = None
with ops.device(device):
tensors = []
for spec in saveable.specs:
tensors.append(
io_ops.restore_v2(
file_prefix,
[spec.name],
[spec.slice_spec],
[spec.dtype])[0])
restore_ops.append(saveable.restore(tensors, restored_shapes=None))
return control_flow_ops.group(restore_ops)
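
The test below exercises the eager path; a graph-mode sketch of the same API (session handling and prefix are illustrative):

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import variables
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util

v = variables.Variable(1.0, name="v")
saver = functional_saver.Saver(
    saveable_object_util.saveable_objects_for_op(v, "v"))
prefix = constant_op.constant("/tmp/ckpt")  # illustrative prefix
save_tensor = saver.save(prefix)    # string tensor with a control dep on SaveV2
restore_op = saver.restore(prefix)  # grouped restore ops when graph building

with session_lib.Session() as sess:
  sess.run(v.initializer)
  sess.run(save_tensor)      # writes the checkpoint
  sess.run(v.assign(2.0))
  sess.run(restore_op)       # v is 1.0 again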

tensorflow/python/training/saving/functional_saver_test.py

@@ -0,0 +1,50 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for the functional saver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util
class SaverTest(test.TestCase):
def test_resource_variable(self):
v1 = resource_variable_ops.ResourceVariable(2.)
saver = functional_saver.Saver(
saveable_object_util.saveable_objects_for_op(v1, "x"))
prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_path = saver.save(constant_op.constant(prefix))
v1.assign(1.)
saver.restore(save_path)
self.assertEqual(2., self.evaluate(v1))
v2 = resource_variable_ops.ResourceVariable(3.)
second_saver = functional_saver.Saver(
saveable_object_util.saveable_objects_for_op(v2, "x"))
second_saver.restore(save_path)
self.assertEqual(2., self.evaluate(v2))
if __name__ == "__main__":
test.main()

tensorflow/python/training/saving/saveable_object_util.py

@@ -0,0 +1,340 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
"VariableV2",
"AutoReloadVariable",
"VarHandleOp",
"ReadVariableOp"])
def set_cpu0(device_string):
"""Creates a new device string based on `device_string` but using /CPU:0.
If the device is already on /CPU:0, this is a no-op.
Args:
device_string: A device string.
Returns:
A device string.
"""
parsed_device = pydev.DeviceSpec.from_string(device_string)
parsed_device.device_type = "CPU"
parsed_device.device_index = 0
return parsed_device.to_string()
class ReferenceVariableSaveable(saveable_object.SaveableObject):
"""SaveableObject implementation that handles reference variables."""
def __init__(self, var, slice_spec, name):
spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
super(ReferenceVariableSaveable, self).__init__(var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
restored_tensor = restored_tensors[0]
if restored_shapes is not None:
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
return state_ops.assign(
self.op,
restored_tensor,
validate_shape=restored_shapes is None and
self.op.get_shape().is_fully_defined())
class ResourceVariableSaveable(saveable_object.SaveableObject):
"""SaveableObject implementation that handles ResourceVariables."""
def __init__(self, var, slice_spec, name):
self._var_device = var.device
self._var_shape = var.shape
if isinstance(var, ops.Tensor):
self.handle_op = var.op.inputs[0]
tensor = var
elif isinstance(var, resource_variable_ops.ResourceVariable):
def _read_variable_closure(v):
def f():
with ops.device(v.device):
x = v.read_value()
# To allow variables placed on non-CPU devices to be checkpointed,
# we copy them to CPU on the same machine first.
with ops.device("/device:CPU:0"):
return array_ops.identity(x)
return f
self.handle_op = var.handle
tensor = _read_variable_closure(var)
else:
raise ValueError(
"Saveable is neither a resource variable nor a read operation."
" Got: %s" % repr(var))
spec = saveable_object.SaveSpec(tensor, slice_spec, name,
dtype=var.dtype)
super(ResourceVariableSaveable, self).__init__(var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
restored_tensor = restored_tensors[0]
if restored_shapes is not None:
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
# Copy the restored tensor to the variable's device.
with ops.device(self._var_device):
restored_tensor = array_ops.identity(restored_tensor)
return resource_variable_ops.shape_safe_assign_variable_handle(
self.handle_op, self._var_shape, restored_tensor)
def _tensor_comes_from_variable(v):
return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
def saveable_objects_for_op(op, name):
"""Create `SaveableObject`s from an operation.
Args:
op: A variable, operation, or SaveableObject to coerce into a
SaveableObject.
name: A string name for the SaveableObject.
Yields:
`SaveableObject`s which together save/restore `op`.
Raises:
TypeError: If `name` is not a string.
ValueError: For operations with no known conversion to SaveableObject.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"names_to_saveables must be a dict mapping string names to "
"checkpointable operations. Name is not a string: %s" % name)
if isinstance(op, saveable_object.SaveableObject):
yield op
elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
if isinstance(op, variables.PartitionedVariable):
op = list(op)
# A set of slices.
slice_name = None
# pylint: disable=protected-access
for variable in op:
if not isinstance(variable, variables.Variable):
raise ValueError("Slices must all be Variables: %s" % variable)
if not variable._save_slice_info:
raise ValueError("Slices must all be slices: %s" % variable)
if slice_name is None:
slice_name = variable._save_slice_info.full_name
elif slice_name != variable._save_slice_info.full_name:
raise ValueError(
"Slices must all be from the same tensor: %s != %s" %
(slice_name, variable._save_slice_info.full_name))
if variable.op.type in ["Variable", "VariableV2",
"AutoReloadVariable"]:
yield ReferenceVariableSaveable(
variable, variable._save_slice_info.spec, name)
else:
yield ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
op, variables.Variable):
# pylint: disable=protected-access
for attr, factory in op._gather_saveables_for_checkpoint().items():
if attr == checkpointable.VARIABLE_VALUE_KEY:
# Keep original name for classes masquerading as variables.
full_name = name
else:
full_name = name + "_" + attr
op = (factory(full_name) if callable(factory) else factory)
for op in saveable_objects_for_op(op, op.name):
yield op
# pylint: enable=protected-access
else:
# A variable or tensor.
if isinstance(op, resource_variable_ops.ResourceVariable):
# pylint: disable=protected-access
if op._in_graph_mode:
variable = op._graph_element
else:
variable = op
# pylint: enable=protected-access
yield ResourceVariableSaveable(variable, "", name)
else:
with ops.init_scope():
if context.executing_eagerly():
raise ValueError("Can only save/restore ResourceVariables when "
"executing eagerly, got type: %s." % type(op))
variable = ops.internal_convert_to_tensor(op, as_ref=True)
if not _tensor_comes_from_variable(variable):
raise TypeError("names_to_saveables must be a dict mapping string "
"names to Tensors/Variables. Not a variable: %s" %
variable)
if variable.op.type in ["Variable", "VariableV2",
"AutoReloadVariable"]:
yield ReferenceVariableSaveable(variable, "", name)
else:
yield ResourceVariableSaveable(
variable, "", name)
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
"""Create a dictionary of names to operation lists.
Args:
op_list: A list, tuple, or set of Variables or SaveableObjects.
convert_variable_to_tensor: Whether or not to convert single Variables
with no slice info into Tensors.
Returns:
A dictionary of names to the operations that must be saved under
that name. Variables with save_slice_info are grouped together under the
same key in no particular order.
Raises:
TypeError: If the type of op_list or its elements is not supported.
ValueError: If at least two saveables share the same name.
"""
if not isinstance(op_list, (list, tuple, set)):
raise TypeError("Variables to save should be passed in a dict or a "
"list: %s" % op_list)
# When ResourceVariables are converted to Tensors, read ops are added to the
# graph. Sorting the op_list ensures that the resulting graph is always
# constructed in a deterministic way:
op_list = sorted(op_list, key=lambda x: x.name)
names_to_saveables = {}
# pylint: disable=protected-access
for var in op_list:
if isinstance(var, saveable_object.SaveableObject):
names_to_saveables[var.name] = var
elif isinstance(var, variables.PartitionedVariable):
if var.name in names_to_saveables:
raise ValueError("At least two variables have the same name: %s" %
var.name)
names_to_saveables[var.name] = var
elif isinstance(var, variables.Variable) and var._save_slice_info:
name = var._save_slice_info.full_name
if name in names_to_saveables:
if not isinstance(names_to_saveables[name], list):
raise ValueError("Mixing slices and non-slices with the same name: "
"%s" % name)
names_to_saveables[name].append(var)
else:
names_to_saveables[name] = [var]
elif (isinstance(var, checkpointable.CheckpointableBase)
and not isinstance(var, variables.Variable)):
checkpointable_saveables = [
(factory() if callable(factory) else factory)
for factory in var._gather_saveables_for_checkpoint().values()]
names_to_saveables.update(
op_list_to_dict(checkpointable_saveables))
else:
if context.executing_eagerly():
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"Can only save/restore ResourceVariables when eager execution "
"is enabled, type: %s." % type(var))
set_var = names_to_saveables.setdefault(var._shared_name, var)
if set_var is not var:
raise ValueError(
("Two different ResourceVariable objects with the same "
"shared_name '%s' were passed to the Saver. This likely means "
"that they were created in different Graphs or isolation "
"contexts, and may not be checkpointed together.") %
(var._shared_name,))
else:
if convert_variable_to_tensor:
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var._graph_element # pylint: disable=protected-access
else:
var = ops.internal_convert_to_tensor(var, as_ref=True)
if not _tensor_comes_from_variable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
if var.op.type == "ReadVariableOp":
name = var.op.inputs[0].op.name
else:
name = var.op.name
if name in names_to_saveables:
raise ValueError("At least two variables have the same name: %s" %
name)
names_to_saveables[name] = var
# pylint: enable=protected-access
return names_to_saveables
def _add_saveable(saveables, seen_ops, saveable):
"""Adds the saveable to the saveables list.
Args:
saveables: List to append the SaveableObject to.
seen_ops: Set of the ops of the saveables already processed. Used to
check that each saveable is only saved once.
saveable: The saveable.
Raises:
ValueError: If the saveable has already been processed.
"""
if saveable.op in seen_ops:
raise ValueError("The same saveable will be restored with two names: %s" %
saveable.name)
saveables.append(saveable)
seen_ops.add(saveable.op)
def validate_and_slice_inputs(names_to_saveables):
"""Returns the variables and names that will be used for a Saver.
Args:
names_to_saveables: A dict (k, v) where k is the name of an operation and
v is an operation to save or a BaseSaverBuilder.Saver.
Returns:
A list of SaveableObjects.
Raises:
TypeError: If any of the keys are not strings or any of the
values are not one of Tensor or Variable or a checkpointable operation.
ValueError: If the same operation is given in more than one value
(this also applies to slices of SlicedVariables).
"""
if not isinstance(names_to_saveables, dict):
names_to_saveables = op_list_to_dict(names_to_saveables)
saveables = []
seen_ops = set()
for name, op in sorted(names_to_saveables.items(),
# Avoid comparing ops, sort only by name.
key=lambda x: x[0]):
for converted_saveable_object in saveable_objects_for_op(op, name):
_add_saveable(saveables, seen_ops, converted_saveable_object)
return saveables
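
Taken together these utilities form the conversion pipeline the functional saver relies on: arbitrary user input, to a name dict, to a flat validated list of SaveableObjects. A sketch (graph mode assumed, names illustrative):

from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object_util

a = variables.Variable(1.0, name="a")
b = variables.Variable(2.0, name="b")
# Non-dict input is routed through op_list_to_dict first.
saveables = saveable_object_util.validate_and_slice_inputs([a, b])
print([s.name for s in saveables])  # ["a", "b"]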

tensorflow/python/training/warm_starting_util.py

@@ -28,7 +28,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export
@@ -139,7 +139,7 @@ def _infer_var_name(var):
Returns:
Name of the `var`
"""
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
name_to_var_dict = saveable_object_util.op_list_to_dict(var)
if len(name_to_var_dict) > 1:
raise TypeError("`var` = %s passed as arg violates the constraints. "
"name_to_var_dict = %s" % (var, name_to_var_dict))