Add a functional saver, use it for object-based checkpointing
Pulls some utilities out of saver.py which are necessary to actually use it. The functional saver takes only SaveableObjects, so these are utilities for taking a list of whatever users pass in and converting them to those. One other code move for object-based checkpointing to avoid circular imports. Applications which need a SaverDef still use the old Saver. Serialization to SaverDef will be added to this saver in a followup. Does not actually wrap the new Saver's methods in @tf.function yet, since there are memory issues which need to be fixed first. PiperOrigin-RevId: 224561069
This commit is contained in:
parent
4f543e588a
commit
66ca3cd10d
@ -30,6 +30,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.training import optimizer
|
from tensorflow.python.training import optimizer
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
LOCAL_VARIABLE_NAME = 'local_center_variable'
|
LOCAL_VARIABLE_NAME = 'local_center_variable'
|
||||||
GLOBAL_VARIABLE_NAME = 'global_center_variable'
|
GLOBAL_VARIABLE_NAME = 'global_center_variable'
|
||||||
@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
|
|||||||
if var_list is None:
|
if var_list is None:
|
||||||
var_list = variables.trainable_variables()
|
var_list = variables.trainable_variables()
|
||||||
if not isinstance(var_list, dict):
|
if not isinstance(var_list, dict):
|
||||||
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
|
var_list = saveable_object_util.op_list_to_dict(var_list)
|
||||||
|
|
||||||
swapped_var_list = {}
|
swapped_var_list = {}
|
||||||
for key, var in var_list.items():
|
for key, var in var_list.items():
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.training import moving_averages
|
from tensorflow.python.training import moving_averages
|
||||||
from tensorflow.python.training import optimizer
|
from tensorflow.python.training import optimizer
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
|
|
||||||
class MovingAverageOptimizer(optimizer.Optimizer):
|
class MovingAverageOptimizer(optimizer.Optimizer):
|
||||||
@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer):
|
|||||||
if var_list is None:
|
if var_list is None:
|
||||||
var_list = variables.global_variables()
|
var_list = variables.global_variables()
|
||||||
if not isinstance(var_list, dict):
|
if not isinstance(var_list, dict):
|
||||||
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
|
var_list = saveable_object_util.op_list_to_dict(var_list)
|
||||||
|
|
||||||
v_name_to_tensor = {}
|
v_name_to_tensor = {}
|
||||||
for k, tensor_or_list in six.iteritems(var_list):
|
for k, tensor_or_list in six.iteritems(var_list):
|
||||||
|
@ -3515,13 +3515,13 @@ py_library(
|
|||||||
exclude = [
|
exclude = [
|
||||||
"**/*test*",
|
"**/*test*",
|
||||||
"training/checkpointable/**/*.py",
|
"training/checkpointable/**/*.py",
|
||||||
|
"training/saving/**/*.py",
|
||||||
# The following targets have their own build rules (same name as the
|
# The following targets have their own build rules (same name as the
|
||||||
# file):
|
# file):
|
||||||
"training/basic_session_run_hooks.py",
|
"training/basic_session_run_hooks.py",
|
||||||
"training/checkpoint_management.py",
|
"training/checkpoint_management.py",
|
||||||
"training/distribute.py",
|
"training/distribute.py",
|
||||||
"training/distribution_strategy_context.py",
|
"training/distribution_strategy_context.py",
|
||||||
"training/saveable_object.py",
|
|
||||||
"training/saver.py",
|
"training/saver.py",
|
||||||
"training/session_run_hook.py",
|
"training/session_run_hook.py",
|
||||||
"training/training_util.py",
|
"training/training_util.py",
|
||||||
@ -3596,12 +3596,6 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "saveable_object",
|
|
||||||
srcs = ["training/saveable_object.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "checkpoint_management",
|
name = "checkpoint_management",
|
||||||
srcs = ["training/checkpoint_management.py"],
|
srcs = ["training/checkpoint_management.py"],
|
||||||
@ -3655,7 +3649,6 @@ py_library(
|
|||||||
":platform",
|
":platform",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
":resource_variable_ops",
|
":resource_variable_ops",
|
||||||
":saveable_object",
|
|
||||||
":session",
|
":session",
|
||||||
":state_ops",
|
":state_ops",
|
||||||
":string_ops",
|
":string_ops",
|
||||||
@ -3665,6 +3658,8 @@ py_library(
|
|||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/training/checkpointable:base",
|
"//tensorflow/python/training/checkpointable:base",
|
||||||
|
"//tensorflow/python/training/saving:saveable_object",
|
||||||
|
"//tensorflow/python/training/saving:saveable_object_util",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
@ -30,7 +30,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -311,10 +311,10 @@ def _set_checkpoint_initializer(variable,
|
|||||||
restore_op = io_ops.restore_v2(
|
restore_op = io_ops.restore_v2(
|
||||||
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||||
|
|
||||||
names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
|
names_to_saveables = saveable_object_util.op_list_to_dict([variable])
|
||||||
saveable_objects = []
|
saveable_objects = []
|
||||||
for name, op in names_to_saveables.items():
|
for name, op in names_to_saveables.items():
|
||||||
for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
|
for s in saveable_object_util.saveable_objects_for_op(op, name):
|
||||||
saveable_objects.append(s)
|
saveable_objects.append(s)
|
||||||
|
|
||||||
assert len(saveable_objects) == 1 # Should be only one variable.
|
assert len(saveable_objects) == 1 # Should be only one variable.
|
||||||
|
@ -25,9 +25,9 @@ py_library(
|
|||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:io_ops_gen",
|
"//tensorflow/python:io_ops_gen",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:saveable_object",
|
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
|
"//tensorflow/python/training/saving:saveable_object",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -114,7 +114,6 @@ py_library(
|
|||||||
"//tensorflow/python:init_ops",
|
"//tensorflow/python:init_ops",
|
||||||
"//tensorflow/python:io_ops_gen",
|
"//tensorflow/python:io_ops_gen",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
"//tensorflow/python:saveable_object",
|
|
||||||
"//tensorflow/python:saver",
|
"//tensorflow/python:saver",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
@ -123,6 +122,9 @@ py_library(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
|
"//tensorflow/python/training/saving:functional_saver",
|
||||||
|
"//tensorflow/python/training/saving:saveable_object",
|
||||||
|
"//tensorflow/python/training/saving:saveable_object_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ import weakref
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -34,7 +33,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_io_ops as io_ops
|
from tensorflow.python.ops import gen_io_ops as io_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import serialization
|
from tensorflow.python.util import serialization
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
@ -374,41 +373,10 @@ class _CheckpointPosition(object):
|
|||||||
eagerly.
|
eagerly.
|
||||||
"""
|
"""
|
||||||
(restore_ops,
|
(restore_ops,
|
||||||
named_saveables,
|
tensor_saveables,
|
||||||
python_saveables) = self._gather_ops_or_named_saveables()
|
python_saveables) = self._gather_ops_or_named_saveables()
|
||||||
|
restore_ops.extend(self._checkpoint.restore_saveables(
|
||||||
# Eagerly run restorations for Python state.
|
tensor_saveables, python_saveables))
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(
|
|
||||||
self._checkpoint.save_path_string)
|
|
||||||
for saveable in python_saveables:
|
|
||||||
spec_names = [spec.name for spec in saveable.specs]
|
|
||||||
saveable.python_restore(
|
|
||||||
[reader.get_tensor(name) for name in spec_names])
|
|
||||||
|
|
||||||
# If we have new SaveableObjects, extract and cache restore ops.
|
|
||||||
if named_saveables:
|
|
||||||
validated_saveables = (
|
|
||||||
self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access
|
|
||||||
validated_names = set(saveable.name for saveable in validated_saveables)
|
|
||||||
if set(named_saveables.keys()) != validated_names:
|
|
||||||
raise AssertionError(
|
|
||||||
("Saveable keys changed when validating. Got back %s, was "
|
|
||||||
"expecting %s") % (named_saveables.keys(), validated_names))
|
|
||||||
all_tensors = self._checkpoint.builder.bulk_restore(
|
|
||||||
filename_tensor=self._checkpoint.save_path_tensor,
|
|
||||||
saveables=validated_saveables, preferred_shard=-1,
|
|
||||||
restore_sequentially=False)
|
|
||||||
saveable_index = 0
|
|
||||||
for saveable in validated_saveables:
|
|
||||||
num_specs = len(saveable.specs)
|
|
||||||
saveable_tensors = all_tensors[
|
|
||||||
saveable_index:saveable_index + num_specs]
|
|
||||||
saveable_index += num_specs
|
|
||||||
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
|
|
||||||
if not context.executing_eagerly():
|
|
||||||
assert saveable.name not in self._checkpoint.restore_ops_by_name
|
|
||||||
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
|
|
||||||
restore_ops.append(restore_op)
|
|
||||||
return restore_ops
|
return restore_ops
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -40,11 +40,14 @@ from tensorflow.python.ops import variable_scope
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import optimizer as optimizer_lib
|
from tensorflow.python.training import optimizer as optimizer_lib
|
||||||
from tensorflow.python.training import saveable_object as saveable_object_lib
|
from tensorflow.python.training import saver as v1_saver_lib
|
||||||
from tensorflow.python.training import saver as saver_lib
|
|
||||||
from tensorflow.python.training.checkpointable import base
|
from tensorflow.python.training.checkpointable import base
|
||||||
from tensorflow.python.training.checkpointable import data_structures
|
from tensorflow.python.training.checkpointable import data_structures
|
||||||
from tensorflow.python.training.checkpointable import tracking
|
from tensorflow.python.training.checkpointable import tracking
|
||||||
|
from tensorflow.python.training.saving import functional_saver
|
||||||
|
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -89,7 +92,6 @@ class _CheckpointRestoreCoordinator(object):
|
|||||||
referenced every restore (e.g. for Python state); otherwise they would
|
referenced every restore (e.g. for Python state); otherwise they would
|
||||||
create their own ops every restore.
|
create their own ops every restore.
|
||||||
"""
|
"""
|
||||||
self.builder = saver_lib.BulkSaverBuilder()
|
|
||||||
self.object_graph_proto = object_graph_proto
|
self.object_graph_proto = object_graph_proto
|
||||||
self.restore_uid = ops.uid()
|
self.restore_uid = ops.uid()
|
||||||
# Maps from objects to lists of attributes which were in the checkpoint but
|
# Maps from objects to lists of attributes which were in the checkpoint but
|
||||||
@ -144,6 +146,57 @@ class _CheckpointRestoreCoordinator(object):
|
|||||||
if self.new_restore_ops_callback:
|
if self.new_restore_ops_callback:
|
||||||
self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
|
self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
|
||||||
|
|
||||||
|
def restore_saveables(self, tensor_saveables, python_saveables):
|
||||||
|
"""Run or build restore operations for SaveableObjects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor_saveables: `SaveableObject`s which correspond to Tensors.
|
||||||
|
python_saveables: `PythonStateSaveable`s which correspond to Python
|
||||||
|
values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
When graph building, a list of restore operations, either cached or newly
|
||||||
|
created, to restore `tensor_saveables`.
|
||||||
|
"""
|
||||||
|
restore_ops = []
|
||||||
|
# Eagerly run restorations for Python state.
|
||||||
|
reader = pywrap_tensorflow.NewCheckpointReader(
|
||||||
|
self.save_path_string)
|
||||||
|
for saveable in python_saveables:
|
||||||
|
spec_names = [spec.name for spec in saveable.specs]
|
||||||
|
saveable.python_restore(
|
||||||
|
[reader.get_tensor(name) for name in spec_names])
|
||||||
|
|
||||||
|
# If we have new SaveableObjects, extract and cache restore ops.
|
||||||
|
if tensor_saveables:
|
||||||
|
validated_saveables = saveable_object_util.validate_and_slice_inputs(
|
||||||
|
tensor_saveables)
|
||||||
|
validated_names = set(saveable.name for saveable in validated_saveables)
|
||||||
|
if set(tensor_saveables.keys()) != validated_names:
|
||||||
|
raise AssertionError(
|
||||||
|
("Saveable keys changed when validating. Got back %s, was "
|
||||||
|
"expecting %s") % (tensor_saveables.keys(), validated_names))
|
||||||
|
for saveable in validated_saveables:
|
||||||
|
if saveable.device:
|
||||||
|
device = saveable_object_util.set_cpu0(saveable.device)
|
||||||
|
else:
|
||||||
|
device = None
|
||||||
|
with ops.device(device):
|
||||||
|
tensors = []
|
||||||
|
for spec in saveable.specs:
|
||||||
|
tensors.append(
|
||||||
|
io_ops.restore_v2(
|
||||||
|
self.save_path_tensor,
|
||||||
|
[spec.name],
|
||||||
|
[spec.slice_spec],
|
||||||
|
[spec.dtype])[0])
|
||||||
|
restore_op = saveable.restore(tensors, restored_shapes=None)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
assert saveable.name not in self.restore_ops_by_name
|
||||||
|
self.restore_ops_by_name[saveable.name] = restore_op
|
||||||
|
restore_ops.append(restore_op)
|
||||||
|
return restore_ops
|
||||||
|
|
||||||
|
|
||||||
class _NameBasedRestoreCoordinator(object):
|
class _NameBasedRestoreCoordinator(object):
|
||||||
"""Keeps the status of a name-based checkpoint restore."""
|
"""Keeps the status of a name-based checkpoint restore."""
|
||||||
@ -183,11 +236,11 @@ class _NameBasedRestoreCoordinator(object):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
saveable = saveable_factory
|
saveable = saveable_factory
|
||||||
names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict(
|
names_to_saveables = saveable_object_util.op_list_to_dict(
|
||||||
[saveable],
|
[saveable],
|
||||||
convert_variable_to_tensor=False)
|
convert_variable_to_tensor=False)
|
||||||
for name, op in names_to_saveables.items():
|
for name, op in names_to_saveables.items():
|
||||||
for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
|
for saveable_object in saveable_object_util.saveable_objects_for_op(
|
||||||
op=op, name=name):
|
op=op, name=name):
|
||||||
yield saveable_object
|
yield saveable_object
|
||||||
|
|
||||||
@ -606,10 +659,10 @@ def _add_attributes_to_object_graph(
|
|||||||
# Figure out the name-based Saver's name for this variable. If it's
|
# Figure out the name-based Saver's name for this variable. If it's
|
||||||
# already a SaveableObject we'd just get the checkpoint key back, so
|
# already a SaveableObject we'd just get the checkpoint key back, so
|
||||||
# we leave full_name blank.
|
# we leave full_name blank.
|
||||||
saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
|
saver_dict = saveable_object_util.op_list_to_dict(
|
||||||
[maybe_saveable], convert_variable_to_tensor=False)
|
[maybe_saveable], convert_variable_to_tensor=False)
|
||||||
full_name, = saver_dict.keys()
|
full_name, = saver_dict.keys()
|
||||||
saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
|
saveables = tuple(saveable_object_util.saveable_objects_for_op(
|
||||||
op=maybe_saveable, name=attribute.checkpoint_key))
|
op=maybe_saveable, name=attribute.checkpoint_key))
|
||||||
for saveable in saveables:
|
for saveable in saveables:
|
||||||
saveable.full_name = full_name
|
saveable.full_name = full_name
|
||||||
@ -1226,7 +1279,7 @@ class NameBasedSaverStatus(_LoadStatus):
|
|||||||
session = ops.get_default_session()
|
session = ops.get_default_session()
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
saveables = self._gather_saveable_objects()
|
saveables = self._gather_saveable_objects()
|
||||||
saver_lib.Saver(saveables).restore(
|
v1_saver_lib.Saver(saveables).restore(
|
||||||
sess=session, save_path=self._checkpoint.save_path)
|
sess=session, save_path=self._checkpoint.save_path)
|
||||||
|
|
||||||
def initialize_or_restore(self, session=None):
|
def initialize_or_restore(self, session=None):
|
||||||
@ -1251,18 +1304,6 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
|
|||||||
fetches=fetches, feed_dict=feed_dict, **kwargs)
|
fetches=fetches, feed_dict=feed_dict, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _copy_saver_with_new_var_list(old_saver, new_var_list):
|
|
||||||
"""Copy a `tf.train.Saver`'s state to a new Saver with different variables."""
|
|
||||||
new_saver = saver_lib.Saver(var_list=new_var_list, max_to_keep=None)
|
|
||||||
# TODO(allenl): Move to copying functionality to Saver?
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
new_saver._last_checkpoints = old_saver._last_checkpoints
|
|
||||||
new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted
|
|
||||||
new_saver._next_checkpoint_time = old_saver._next_checkpoint_time
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
return new_saver
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointableSaver(object):
|
class CheckpointableSaver(object):
|
||||||
"""Saves and restores a `Checkpointable` object and its dependencies.
|
"""Saves and restores a `Checkpointable` object and its dependencies.
|
||||||
|
|
||||||
@ -1301,7 +1342,8 @@ class CheckpointableSaver(object):
|
|||||||
# Op caching for save
|
# Op caching for save
|
||||||
self._object_graph_feed_tensor = None
|
self._object_graph_feed_tensor = None
|
||||||
self._last_save_object_graph = None
|
self._last_save_object_graph = None
|
||||||
self._last_save_saver = None
|
self._file_prefix_feed_tensor = None
|
||||||
|
self._cached_save_operation = None
|
||||||
|
|
||||||
# Op caching for restore, shared between _CheckpointRestoreCoordinators
|
# Op caching for restore, shared between _CheckpointRestoreCoordinators
|
||||||
self._restore_op_cache = {}
|
self._restore_op_cache = {}
|
||||||
@ -1368,11 +1410,14 @@ class CheckpointableSaver(object):
|
|||||||
base.NoRestoreSaveable(
|
base.NoRestoreSaveable(
|
||||||
tensor=object_graph_tensor,
|
tensor=object_graph_tensor,
|
||||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||||
# TODO(allenl, haoliang): Swap in a function-based saver here.
|
# TODO(allenl): Swap in a function-based saver here once it can serialize
|
||||||
return saver_lib.Saver(
|
# to a SaverDef.
|
||||||
|
return v1_saver_lib.Saver(
|
||||||
var_list=named_saveable_objects, max_to_keep=None)
|
var_list=named_saveable_objects, max_to_keep=None)
|
||||||
|
|
||||||
def _prepare_save(self,
|
def _save_cached_when_graph_building(
|
||||||
|
self,
|
||||||
|
file_prefix,
|
||||||
object_graph_tensor=None,
|
object_graph_tensor=None,
|
||||||
saveable_object_cache=None):
|
saveable_object_cache=None):
|
||||||
"""Create or retrieve save ops.
|
"""Create or retrieve save ops.
|
||||||
@ -1383,15 +1428,17 @@ class CheckpointableSaver(object):
|
|||||||
unnecessarily re-creating save ops.
|
unnecessarily re-creating save ops.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
file_prefix: The prefix for saved checkpoint files.
|
||||||
object_graph_tensor: A `Tensor` to which the current object graph will be
|
object_graph_tensor: A `Tensor` to which the current object graph will be
|
||||||
fed.
|
fed.
|
||||||
saveable_object_cache: A dictionary; if specified, used to cache
|
saveable_object_cache: A dictionary; if specified, used to cache
|
||||||
`SaveableObject`s.
|
`SaveableObject`s.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
|
A two-element tuple with a filename tensor and a feed_dict of tensors to
|
||||||
to feed when running save ops. The feed dict contains the current object
|
feed when running it (if graph building). The feed dict contains the
|
||||||
graph and any Python state to be saved in the checkpoint.
|
current object graph and any Python state to be saved in the
|
||||||
|
checkpoint. When executing eagerly only the first argument is meaningful.
|
||||||
"""
|
"""
|
||||||
(named_saveable_objects, graph_proto,
|
(named_saveable_objects, graph_proto,
|
||||||
feed_additions) = self._gather_saveables(
|
feed_additions) = self._gather_saveables(
|
||||||
@ -1403,15 +1450,11 @@ class CheckpointableSaver(object):
|
|||||||
# constructors. That means the Saver needs to be copied with a new
|
# constructors. That means the Saver needs to be copied with a new
|
||||||
# var_list.
|
# var_list.
|
||||||
or context.executing_eagerly()):
|
or context.executing_eagerly()):
|
||||||
if self._last_save_object_graph is not None:
|
saver = functional_saver.Saver(named_saveable_objects)
|
||||||
self._last_save_saver = _copy_saver_with_new_var_list(
|
with ops.device("/cpu:0"):
|
||||||
old_saver=self._last_save_saver,
|
self._cached_save_operation = saver.save(file_prefix)
|
||||||
new_var_list=named_saveable_objects)
|
|
||||||
else:
|
|
||||||
self._last_save_saver = saver_lib.Saver(
|
|
||||||
var_list=named_saveable_objects, max_to_keep=None)
|
|
||||||
self._last_save_object_graph = graph_proto
|
self._last_save_object_graph = graph_proto
|
||||||
return self._last_save_saver, feed_additions
|
return self._cached_save_operation, feed_additions
|
||||||
|
|
||||||
def save(self, file_prefix, checkpoint_number=None, session=None):
|
def save(self, file_prefix, checkpoint_number=None, session=None):
|
||||||
"""Save a training checkpoint.
|
"""Save a training checkpoint.
|
||||||
@ -1435,36 +1478,42 @@ class CheckpointableSaver(object):
|
|||||||
Returns:
|
Returns:
|
||||||
The full path to the checkpoint.
|
The full path to the checkpoint.
|
||||||
"""
|
"""
|
||||||
feed_additions = {}
|
feed_dict = {}
|
||||||
graph_building = not context.executing_eagerly()
|
graph_building = not context.executing_eagerly()
|
||||||
|
if checkpoint_number:
|
||||||
|
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
|
||||||
if graph_building:
|
if graph_building:
|
||||||
if self._object_graph_feed_tensor is None:
|
if self._object_graph_feed_tensor is None:
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
self._object_graph_feed_tensor = constant_op.constant(
|
self._object_graph_feed_tensor = constant_op.constant(
|
||||||
"", dtype=dtypes.string)
|
"", dtype=dtypes.string)
|
||||||
|
self._file_prefix_feed_tensor = constant_op.constant(
|
||||||
|
"", dtype=dtypes.string)
|
||||||
object_graph_tensor = self._object_graph_feed_tensor
|
object_graph_tensor = self._object_graph_feed_tensor
|
||||||
|
file_prefix_tensor = self._file_prefix_feed_tensor
|
||||||
|
feed_dict[file_prefix_tensor] = file_prefix
|
||||||
else:
|
else:
|
||||||
|
with ops.device("/cpu:0"):
|
||||||
|
file_prefix_tensor = constant_op.constant(
|
||||||
|
file_prefix, dtype=dtypes.string)
|
||||||
object_graph_tensor = None
|
object_graph_tensor = None
|
||||||
|
|
||||||
saver, new_feed_additions = self._prepare_save(
|
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
||||||
|
save_path, new_feed_additions = self._save_cached_when_graph_building(
|
||||||
|
file_prefix=file_prefix_tensor,
|
||||||
object_graph_tensor=object_graph_tensor,
|
object_graph_tensor=object_graph_tensor,
|
||||||
saveable_object_cache=self._saveable_object_cache)
|
saveable_object_cache=self._saveable_object_cache)
|
||||||
if new_feed_additions:
|
if new_feed_additions:
|
||||||
feed_additions.update(new_feed_additions)
|
feed_dict.update(new_feed_additions)
|
||||||
if not graph_building:
|
if not graph_building:
|
||||||
session = None
|
session = None
|
||||||
elif session is None:
|
elif session is None:
|
||||||
session = ops.get_default_session()
|
session = ops.get_default_session()
|
||||||
|
|
||||||
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
if session:
|
||||||
with ops.device("/cpu:0"):
|
save_path = session.run(save_path, feed_dict=feed_dict)
|
||||||
save_path = saver.save(
|
else:
|
||||||
sess=_SessionWithFeedDictAdditions(
|
save_path = save_path.numpy()
|
||||||
session=session, feed_additions=feed_additions),
|
|
||||||
save_path=file_prefix,
|
|
||||||
write_meta_graph=False,
|
|
||||||
write_state=False,
|
|
||||||
global_step=checkpoint_number)
|
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def restore(self, save_path):
|
def restore(self, save_path):
|
||||||
@ -1753,9 +1802,9 @@ class Checkpoint(tracking.Checkpointable):
|
|||||||
Returns:
|
Returns:
|
||||||
The full path to the checkpoint (i.e. `file_prefix`).
|
The full path to the checkpoint (i.e. `file_prefix`).
|
||||||
"""
|
"""
|
||||||
return self._saver.save(
|
return compat.as_str(self._saver.save(
|
||||||
file_prefix=file_prefix,
|
file_prefix=file_prefix,
|
||||||
session=session)
|
session=session))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def save_counter(self):
|
def save_counter(self):
|
||||||
|
@ -14,7 +14,11 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
"""Save and restore variables."""
|
"""Save and restore variables.
|
||||||
|
|
||||||
|
Symbols in this file are deprecated. See replacements in
|
||||||
|
tensorflow/python/training/checkpointable and tensorflow/python/training/saving.
|
||||||
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -25,7 +29,6 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
|
||||||
|
|
||||||
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
|
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
@ -42,16 +45,15 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_io_ops
|
from tensorflow.python.ops import gen_io_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
|
||||||
from tensorflow.python.ops import state_ops
|
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import saveable_object
|
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||||
|
from tensorflow.python.training.saving import saveable_object
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -67,31 +69,6 @@ get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
|
|||||||
remove_checkpoint = checkpoint_management.remove_checkpoint
|
remove_checkpoint = checkpoint_management.remove_checkpoint
|
||||||
|
|
||||||
|
|
||||||
# Op names which identify variable reads which should be saved.
|
|
||||||
_VARIABLE_OPS = set(["Variable",
|
|
||||||
"VariableV2",
|
|
||||||
"AutoReloadVariable",
|
|
||||||
"VarHandleOp",
|
|
||||||
"ReadVariableOp"])
|
|
||||||
|
|
||||||
|
|
||||||
def _set_cpu0(device_string):
|
|
||||||
"""Creates a new device string based on `device_string` but using /CPU:0.
|
|
||||||
|
|
||||||
If the device is already on /CPU:0, this is a no-op.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device_string: A device string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A device string.
|
|
||||||
"""
|
|
||||||
parsed_device = pydev.DeviceSpec.from_string(device_string)
|
|
||||||
parsed_device.device_type = "CPU"
|
|
||||||
parsed_device.device_index = 0
|
|
||||||
return parsed_device.to_string()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSaverBuilder(object):
|
class BaseSaverBuilder(object):
|
||||||
"""Base class for Savers.
|
"""Base class for Savers.
|
||||||
|
|
||||||
@ -101,64 +78,9 @@ class BaseSaverBuilder(object):
|
|||||||
SaveSpec = saveable_object.SaveSpec
|
SaveSpec = saveable_object.SaveSpec
|
||||||
SaveableObject = saveable_object.SaveableObject
|
SaveableObject = saveable_object.SaveableObject
|
||||||
|
|
||||||
class VariableSaveable(SaveableObject):
|
# Aliases for code which was moved but still has lots of users.
|
||||||
"""SaveableObject implementation that handles Variables."""
|
VariableSaveable = saveable_object_util.ReferenceVariableSaveable
|
||||||
|
ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
|
||||||
def __init__(self, var, slice_spec, name):
|
|
||||||
spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype)
|
|
||||||
super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
|
|
||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes):
|
|
||||||
restored_tensor = restored_tensors[0]
|
|
||||||
if restored_shapes is not None:
|
|
||||||
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
|
|
||||||
return state_ops.assign(
|
|
||||||
self.op,
|
|
||||||
restored_tensor,
|
|
||||||
validate_shape=restored_shapes is None and
|
|
||||||
self.op.get_shape().is_fully_defined())
|
|
||||||
|
|
||||||
class ResourceVariableSaveable(SaveableObject):
|
|
||||||
"""SaveableObject implementation that handles ResourceVariables."""
|
|
||||||
|
|
||||||
def __init__(self, var, slice_spec, name):
|
|
||||||
self._var_device = var.device
|
|
||||||
self._var_shape = var.shape
|
|
||||||
if isinstance(var, ops.Tensor):
|
|
||||||
self.handle_op = var.op.inputs[0]
|
|
||||||
tensor = var
|
|
||||||
elif isinstance(var, resource_variable_ops.ResourceVariable):
|
|
||||||
|
|
||||||
def _read_variable_closure(v):
|
|
||||||
def f():
|
|
||||||
with ops.device(v.device):
|
|
||||||
x = v.read_value()
|
|
||||||
# To allow variables placed on non-CPU devices to be checkpointed,
|
|
||||||
# we copy them to CPU on the same machine first.
|
|
||||||
with ops.device("/device:CPU:0"):
|
|
||||||
return array_ops.identity(x)
|
|
||||||
return f
|
|
||||||
|
|
||||||
self.handle_op = var.handle
|
|
||||||
tensor = _read_variable_closure(var)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Saveable is neither a resource variable nor a read operation."
|
|
||||||
" Got: %s" % repr(var))
|
|
||||||
spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name,
|
|
||||||
dtype=var.dtype)
|
|
||||||
super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
|
|
||||||
var, [spec], name)
|
|
||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes):
|
|
||||||
restored_tensor = restored_tensors[0]
|
|
||||||
if restored_shapes is not None:
|
|
||||||
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
|
|
||||||
# Copy the restored tensor to the variable's device.
|
|
||||||
with ops.device(self._var_device):
|
|
||||||
restored_tensor = array_ops.identity(restored_tensor)
|
|
||||||
return resource_variable_ops.shape_safe_assign_variable_handle(
|
|
||||||
self.handle_op, self._var_shape, restored_tensor)
|
|
||||||
|
|
||||||
def __init__(self, write_version=saver_pb2.SaverDef.V2):
|
def __init__(self, write_version=saver_pb2.SaverDef.V2):
|
||||||
self._write_version = write_version
|
self._write_version = write_version
|
||||||
@ -224,7 +146,11 @@ class BaseSaverBuilder(object):
|
|||||||
del restore_sequentially
|
del restore_sequentially
|
||||||
all_tensors = []
|
all_tensors = []
|
||||||
for saveable in saveables:
|
for saveable in saveables:
|
||||||
with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
|
if saveable.device:
|
||||||
|
device = saveable_object_util.set_cpu0(saveable.device)
|
||||||
|
else:
|
||||||
|
device = None
|
||||||
|
with ops.device(device):
|
||||||
all_tensors.extend(
|
all_tensors.extend(
|
||||||
self.restore_op(filename_tensor, saveable, preferred_shard))
|
self.restore_op(filename_tensor, saveable, preferred_shard))
|
||||||
return all_tensors
|
return all_tensors
|
||||||
@ -336,7 +262,7 @@ class BaseSaverBuilder(object):
|
|||||||
last_device = None
|
last_device = None
|
||||||
for shard, (device, saveables) in enumerate(per_device):
|
for shard, (device, saveables) in enumerate(per_device):
|
||||||
last_device = device
|
last_device = device
|
||||||
with ops.device(_set_cpu0(device)):
|
with ops.device(saveable_object_util.set_cpu0(device)):
|
||||||
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
|
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
|
||||||
num_shards_tensor)
|
num_shards_tensor)
|
||||||
sharded_prefixes.append(sharded_filename)
|
sharded_prefixes.append(sharded_filename)
|
||||||
@ -344,7 +270,7 @@ class BaseSaverBuilder(object):
|
|||||||
|
|
||||||
with ops.control_dependencies([x.op for x in sharded_saves]):
|
with ops.control_dependencies([x.op for x in sharded_saves]):
|
||||||
# Co-locates the merge step with the last device.
|
# Co-locates the merge step with the last device.
|
||||||
with ops.device(_set_cpu0(last_device)):
|
with ops.device(saveable_object_util.set_cpu0(last_device)):
|
||||||
# V2 format write path consists of a metadata merge step. Once merged,
|
# V2 format write path consists of a metadata merge step. Once merged,
|
||||||
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
||||||
merge_step = gen_io_ops.merge_v2_checkpoints(
|
merge_step = gen_io_ops.merge_v2_checkpoints(
|
||||||
@ -459,10 +385,6 @@ class BaseSaverBuilder(object):
|
|||||||
name="restore_shard"))
|
name="restore_shard"))
|
||||||
return control_flow_ops.group(*sharded_restores, name="restore_all")
|
return control_flow_ops.group(*sharded_restores, name="restore_all")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _IsVariable(v):
|
|
||||||
return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
|
|
||||||
|
|
||||||
def _GroupByDevices(self, saveables):
|
def _GroupByDevices(self, saveables):
|
||||||
"""Group Variable tensor slices per device.
|
"""Group Variable tensor slices per device.
|
||||||
|
|
||||||
@ -490,220 +412,6 @@ class BaseSaverBuilder(object):
|
|||||||
per_device[canonical_device.pop()].append(saveable)
|
per_device[canonical_device.pop()].append(saveable)
|
||||||
return sorted(per_device.items(), key=lambda t: t[0])
|
return sorted(per_device.items(), key=lambda t: t[0])
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def OpListToDict(op_list, convert_variable_to_tensor=True):
|
|
||||||
"""Create a dictionary of names to operation lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_list: A list, tuple, or set of Variables or SaveableObjects.
|
|
||||||
convert_variable_to_tensor: Whether or not to convert single Variables
|
|
||||||
with no slice info into Tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of names to the operations that must be saved under
|
|
||||||
that name. Variables with save_slice_info are grouped together under the
|
|
||||||
same key in no particular order.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the type of op_list or its elements is not supported.
|
|
||||||
ValueError: If at least two saveables share the same name.
|
|
||||||
"""
|
|
||||||
if not isinstance(op_list, (list, tuple, set)):
|
|
||||||
raise TypeError("Variables to save should be passed in a dict or a "
|
|
||||||
"list: %s" % op_list)
|
|
||||||
# When ResourceVariables are converted to Tensors, read ops are added to the
|
|
||||||
# graph. Sorting the op_list ensures that the resulting graph is always
|
|
||||||
# constructed in a deterministic way:
|
|
||||||
op_list = sorted(op_list, key=lambda x: x.name)
|
|
||||||
names_to_saveables = {}
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
for var in op_list:
|
|
||||||
if isinstance(var, BaseSaverBuilder.SaveableObject):
|
|
||||||
names_to_saveables[var.name] = var
|
|
||||||
elif isinstance(var, variables.PartitionedVariable):
|
|
||||||
if var.name in names_to_saveables:
|
|
||||||
raise ValueError("At least two variables have the same name: %s" %
|
|
||||||
var.name)
|
|
||||||
names_to_saveables[var.name] = var
|
|
||||||
elif isinstance(var, variables.Variable) and var._save_slice_info:
|
|
||||||
name = var._save_slice_info.full_name
|
|
||||||
if name in names_to_saveables:
|
|
||||||
if not isinstance(names_to_saveables[name], list):
|
|
||||||
raise ValueError("Mixing slices and non-slices with the same name: "
|
|
||||||
"%s" % name)
|
|
||||||
names_to_saveables[name].append(var)
|
|
||||||
else:
|
|
||||||
names_to_saveables[name] = [var]
|
|
||||||
elif (isinstance(var, checkpointable.CheckpointableBase)
|
|
||||||
and not isinstance(var, variables.Variable)):
|
|
||||||
checkpointable_saveables = [
|
|
||||||
(factory() if callable(factory) else factory)
|
|
||||||
for factory in var._gather_saveables_for_checkpoint().values()]
|
|
||||||
names_to_saveables.update(
|
|
||||||
BaseSaverBuilder.OpListToDict(checkpointable_saveables))
|
|
||||||
else:
|
|
||||||
if context.executing_eagerly():
|
|
||||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
|
||||||
raise ValueError(
|
|
||||||
"Can only save/restore ResourceVariables when eager execution "
|
|
||||||
"is enabled, type: %s." % type(var))
|
|
||||||
set_var = names_to_saveables.setdefault(var._shared_name, var)
|
|
||||||
if set_var is not var:
|
|
||||||
raise ValueError(
|
|
||||||
("Two different ResourceVariable objects with the same "
|
|
||||||
"shared_name '%s' were passed to the Saver. This likely means "
|
|
||||||
"that they were created in different Graphs or isolation "
|
|
||||||
"contexts, and may not be checkpointed together.") %
|
|
||||||
(var._shared_name,))
|
|
||||||
else:
|
|
||||||
if convert_variable_to_tensor:
|
|
||||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
|
||||||
var = var._graph_element # pylint: disable=protected-access
|
|
||||||
else:
|
|
||||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
|
||||||
if not BaseSaverBuilder._IsVariable(var):
|
|
||||||
raise TypeError("Variable to save is not a Variable: %s" % var)
|
|
||||||
if var.op.type == "ReadVariableOp":
|
|
||||||
name = var.op.inputs[0].op.name
|
|
||||||
else:
|
|
||||||
name = var.op.name
|
|
||||||
if name in names_to_saveables:
|
|
||||||
raise ValueError("At least two variables have the same name: %s" %
|
|
||||||
name)
|
|
||||||
names_to_saveables[name] = var
|
|
||||||
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
return names_to_saveables
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def SaveableObjectsForOp(op, name):
|
|
||||||
"""Create `SaveableObject`s from an operation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: A variable, operation, or SaveableObject to coerce into a
|
|
||||||
SaveableObject.
|
|
||||||
name: A string name for the SaveableObject.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
`SaveableObject`s which together save/restore `op`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If `name` is not a string.
|
|
||||||
ValueError: For operations with no known conversion to SaveableObject.
|
|
||||||
"""
|
|
||||||
if not isinstance(name, six.string_types):
|
|
||||||
raise TypeError(
|
|
||||||
"names_to_saveables must be a dict mapping string names to "
|
|
||||||
"checkpointable operations. Name is not a string: %s" % name)
|
|
||||||
if isinstance(op, BaseSaverBuilder.SaveableObject):
|
|
||||||
yield op
|
|
||||||
elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
|
|
||||||
if isinstance(op, variables.PartitionedVariable):
|
|
||||||
op = list(op)
|
|
||||||
# A set of slices.
|
|
||||||
slice_name = None
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
for variable in op:
|
|
||||||
if not isinstance(variable, variables.Variable):
|
|
||||||
raise ValueError("Slices must all be Variables: %s" % variable)
|
|
||||||
if not variable._save_slice_info:
|
|
||||||
raise ValueError("Slices must all be slices: %s" % variable)
|
|
||||||
if slice_name is None:
|
|
||||||
slice_name = variable._save_slice_info.full_name
|
|
||||||
elif slice_name != variable._save_slice_info.full_name:
|
|
||||||
raise ValueError(
|
|
||||||
"Slices must all be from the same tensor: %s != %s" %
|
|
||||||
(slice_name, variable._save_slice_info.full_name))
|
|
||||||
if variable.op.type in ["Variable", "VariableV2",
|
|
||||||
"AutoReloadVariable"]:
|
|
||||||
yield BaseSaverBuilder.VariableSaveable(
|
|
||||||
variable, variable._save_slice_info.spec, name)
|
|
||||||
else:
|
|
||||||
yield BaseSaverBuilder.ResourceVariableSaveable(
|
|
||||||
variable, variable._save_slice_info.spec, name)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
|
|
||||||
op, variables.Variable):
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
for attr, factory in op._gather_saveables_for_checkpoint().items():
|
|
||||||
if attr == checkpointable.VARIABLE_VALUE_KEY:
|
|
||||||
# Keep original name for classes masquerading as variables.
|
|
||||||
full_name = name
|
|
||||||
else:
|
|
||||||
full_name = name + "_" + attr
|
|
||||||
op = (factory(full_name) if callable(factory) else factory)
|
|
||||||
for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name):
|
|
||||||
yield op
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
else:
|
|
||||||
# A variable or tensor.
|
|
||||||
if context.executing_eagerly():
|
|
||||||
if not isinstance(op, resource_variable_ops.ResourceVariable):
|
|
||||||
raise ValueError("Can only save/restore ResourceVariable eager "
|
|
||||||
"mode is enabled, type: %s." % type(op))
|
|
||||||
yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
|
|
||||||
else:
|
|
||||||
if isinstance(op, resource_variable_ops.ResourceVariable):
|
|
||||||
variable = op._graph_element # pylint: disable=protected-access
|
|
||||||
else:
|
|
||||||
variable = ops.internal_convert_to_tensor(op, as_ref=True)
|
|
||||||
if not BaseSaverBuilder._IsVariable(variable):
|
|
||||||
raise TypeError("names_to_saveables must be a dict mapping string "
|
|
||||||
"names to Tensors/Variables. Not a variable: %s" %
|
|
||||||
variable)
|
|
||||||
if variable.op.type in ["Variable", "VariableV2",
|
|
||||||
"AutoReloadVariable"]:
|
|
||||||
yield BaseSaverBuilder.VariableSaveable(variable, "", name)
|
|
||||||
else:
|
|
||||||
yield BaseSaverBuilder.ResourceVariableSaveable(
|
|
||||||
variable, "", name)
|
|
||||||
|
|
||||||
def _ValidateAndSliceInputs(self, names_to_saveables):
|
|
||||||
"""Returns the variables and names that will be used for a Saver.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
names_to_saveables: A dict (k, v) where k is the name of an operation and
|
|
||||||
v is an operation to save or a BaseSaverBuilder.Saver.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of BaseSaverBuilder.SaveableObject objects.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If any of the keys are not strings or any of the
|
|
||||||
values are not one of Tensor or Variable or a checkpointable operation.
|
|
||||||
ValueError: If the same operation is given in more than one value
|
|
||||||
(this also applies to slices of SlicedVariables).
|
|
||||||
"""
|
|
||||||
if not isinstance(names_to_saveables, dict):
|
|
||||||
names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
|
|
||||||
|
|
||||||
saveables = []
|
|
||||||
seen_ops = set()
|
|
||||||
for name, op in sorted(names_to_saveables.items(),
|
|
||||||
# Avoid comparing ops, sort only by name.
|
|
||||||
key=lambda x: x[0]):
|
|
||||||
for converted_saveable_object in self.SaveableObjectsForOp(op, name):
|
|
||||||
self._AddSaveable(saveables, seen_ops, converted_saveable_object)
|
|
||||||
return saveables
|
|
||||||
|
|
||||||
def _AddSaveable(self, saveables, seen_ops, saveable):
|
|
||||||
"""Adds the saveable to the saveables list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
saveables: List to append the SaveableObject to.
|
|
||||||
seen_ops: Set of the ops of the saveables already processed. Used to
|
|
||||||
check that each saveable is only saved once.
|
|
||||||
saveable: The saveable.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the saveable has already been processed.
|
|
||||||
"""
|
|
||||||
if saveable.op in seen_ops:
|
|
||||||
raise ValueError("The same saveable will be restored with two names: %s" %
|
|
||||||
saveable.name)
|
|
||||||
saveables.append(saveable)
|
|
||||||
seen_ops.add(saveable.op)
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
names_to_saveables,
|
names_to_saveables,
|
||||||
reshape=False,
|
reshape=False,
|
||||||
@ -775,7 +483,8 @@ class BaseSaverBuilder(object):
|
|||||||
raise ValueError("save and restore operations need to be built together "
|
raise ValueError("save and restore operations need to be built together "
|
||||||
" when eager execution is not enabled.")
|
" when eager execution is not enabled.")
|
||||||
|
|
||||||
saveables = self._ValidateAndSliceInputs(names_to_saveables)
|
saveables = saveable_object_util.validate_and_slice_inputs(
|
||||||
|
names_to_saveables)
|
||||||
if max_to_keep is None:
|
if max_to_keep is None:
|
||||||
max_to_keep = 0
|
max_to_keep = 0
|
||||||
|
|
||||||
@ -1910,7 +1619,7 @@ def saver_from_object_based_checkpoint(
|
|||||||
if builder is None:
|
if builder is None:
|
||||||
builder = BulkSaverBuilder()
|
builder = BulkSaverBuilder()
|
||||||
|
|
||||||
saveables = builder._ValidateAndSliceInputs(var_list) # pylint: disable=protected-access
|
saveables = saveable_object_util.validate_and_slice_inputs(var_list)
|
||||||
current_names = set()
|
current_names = set()
|
||||||
for saveable in saveables:
|
for saveable in saveables:
|
||||||
for spec in saveable.specs:
|
for spec in saveable.specs:
|
||||||
|
55
tensorflow/python/training/saving/BUILD
Normal file
55
tensorflow/python/training/saving/BUILD
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Description:
|
||||||
|
# Low-level utilities for reading and writing checkpoints.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "functional_saver",
|
||||||
|
srcs = ["functional_saver.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":saveable_object",
|
||||||
|
":saveable_object_util",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "functional_saver_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = [
|
||||||
|
"functional_saver_test.py",
|
||||||
|
],
|
||||||
|
additional_deps = [
|
||||||
|
":functional_saver",
|
||||||
|
"//tensorflow/python/eager:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "saveable_object",
|
||||||
|
srcs = ["saveable_object.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "saveable_object_util",
|
||||||
|
srcs = ["saveable_object_util.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:resource_variable_ops",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
"//tensorflow/python/training/checkpointable:base",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
101
tensorflow/python/training/saving/functional_saver.py
Normal file
101
tensorflow/python/training/saving/functional_saver.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Saves and restore variables inside traced @tf.functions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import io_ops
|
||||||
|
from tensorflow.python.training.saving import saveable_object
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
|
|
||||||
|
class Saver(object):
|
||||||
|
"""A minimal utility class for saving and restoring checkpoints.
|
||||||
|
|
||||||
|
Note that this is a low-level utility which stores Tensors in the keys
|
||||||
|
specified by `SaveableObject`s. Higher-level utilities for object-based
|
||||||
|
checkpointing are built on top of it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, saveable_objects):
|
||||||
|
"""Specify a list of `SaveableObject`s to save and restore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saveable_objects: A list of `SaveableObject`s.
|
||||||
|
"""
|
||||||
|
saveable_objects = list(saveable_objects)
|
||||||
|
for saveable in saveable_objects:
|
||||||
|
if not isinstance(saveable, saveable_object.SaveableObject):
|
||||||
|
raise ValueError(
|
||||||
|
"Saver expected a list of SaveableObjects, got %s." % (saveable,))
|
||||||
|
self._saveable_objects = saveable_objects
|
||||||
|
|
||||||
|
# TODO(b/120569892): Use tf.function here
|
||||||
|
def save(self, file_prefix):
|
||||||
|
"""Save the saveable objects to a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_prefix: A string or scalar string Tensor containing the prefix to
|
||||||
|
save under.
|
||||||
|
Returns:
|
||||||
|
A scalar string Tensor containing `file_prefix` with control dependencies
|
||||||
|
on the save ops.
|
||||||
|
"""
|
||||||
|
tensor_names = []
|
||||||
|
tensors = []
|
||||||
|
tensor_slices = []
|
||||||
|
for saveable in self._saveable_objects:
|
||||||
|
for spec in saveable.specs:
|
||||||
|
tensor_names.append(spec.name)
|
||||||
|
tensors.append(spec.tensor)
|
||||||
|
tensor_slices.append(spec.slice_spec)
|
||||||
|
with ops.control_dependencies(
|
||||||
|
[io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)]):
|
||||||
|
return array_ops.identity(file_prefix)
|
||||||
|
|
||||||
|
# TODO(b/120569892): Use tf.function here
|
||||||
|
def restore(self, file_prefix):
|
||||||
|
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_prefix: A string or scalar string Tensor containing the prefix for
|
||||||
|
files to read from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An operation which restores the `Saver`'s `SaveableObject`s when run, or
|
||||||
|
None if executing eagerly.
|
||||||
|
"""
|
||||||
|
restore_ops = []
|
||||||
|
for saveable in self._saveable_objects:
|
||||||
|
if saveable.device:
|
||||||
|
device = saveable_object_util.set_cpu0(saveable.device)
|
||||||
|
else:
|
||||||
|
device = None
|
||||||
|
with ops.device(device):
|
||||||
|
tensors = []
|
||||||
|
for spec in saveable.specs:
|
||||||
|
tensors.append(
|
||||||
|
io_ops.restore_v2(
|
||||||
|
file_prefix,
|
||||||
|
[spec.name],
|
||||||
|
[spec.slice_spec],
|
||||||
|
[spec.dtype])[0])
|
||||||
|
restore_ops.append(saveable.restore(tensors, restored_shapes=None))
|
||||||
|
return control_flow_ops.group(restore_ops)
|
50
tensorflow/python/training/saving/functional_saver_test.py
Normal file
50
tensorflow/python/training/saving/functional_saver_test.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Tests for the functional saver."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.training.saving import functional_saver
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
|
|
||||||
|
class SaverTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_resource_variable(self):
|
||||||
|
v1 = resource_variable_ops.ResourceVariable(2.)
|
||||||
|
saver = functional_saver.Saver(
|
||||||
|
saveable_object_util.saveable_objects_for_op(v1, "x"))
|
||||||
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
|
save_path = saver.save(constant_op.constant(prefix))
|
||||||
|
v1.assign(1.)
|
||||||
|
saver.restore(save_path)
|
||||||
|
self.assertEqual(2., self.evaluate(v1))
|
||||||
|
|
||||||
|
v2 = resource_variable_ops.ResourceVariable(3.)
|
||||||
|
second_saver = functional_saver.Saver(
|
||||||
|
saveable_object_util.saveable_objects_for_op(v2, "x"))
|
||||||
|
second_saver.restore(save_path)
|
||||||
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
340
tensorflow/python/training/saving/saveable_object_util.py
Normal file
340
tensorflow/python/training/saving/saveable_object_util.py
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for working with and creating SaveableObjects."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import device as pydev
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops import state_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||||
|
from tensorflow.python.training.saving import saveable_object
|
||||||
|
|
||||||
|
|
||||||
|
# Op names which identify variable reads which should be saved.
|
||||||
|
_VARIABLE_OPS = set(["Variable",
|
||||||
|
"VariableV2",
|
||||||
|
"AutoReloadVariable",
|
||||||
|
"VarHandleOp",
|
||||||
|
"ReadVariableOp"])
|
||||||
|
|
||||||
|
|
||||||
|
def set_cpu0(device_string):
|
||||||
|
"""Creates a new device string based on `device_string` but using /CPU:0.
|
||||||
|
|
||||||
|
If the device is already on /CPU:0, this is a no-op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_string: A device string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A device string.
|
||||||
|
"""
|
||||||
|
parsed_device = pydev.DeviceSpec.from_string(device_string)
|
||||||
|
parsed_device.device_type = "CPU"
|
||||||
|
parsed_device.device_index = 0
|
||||||
|
return parsed_device.to_string()
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceVariableSaveable(saveable_object.SaveableObject):
|
||||||
|
"""SaveableObject implementation that handles reference variables."""
|
||||||
|
|
||||||
|
def __init__(self, var, slice_spec, name):
|
||||||
|
spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)
|
||||||
|
super(ReferenceVariableSaveable, self).__init__(var, [spec], name)
|
||||||
|
|
||||||
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
|
restored_tensor = restored_tensors[0]
|
||||||
|
if restored_shapes is not None:
|
||||||
|
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
|
||||||
|
return state_ops.assign(
|
||||||
|
self.op,
|
||||||
|
restored_tensor,
|
||||||
|
validate_shape=restored_shapes is None and
|
||||||
|
self.op.get_shape().is_fully_defined())
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Creates a saveable for `var`.

    Args:
      var: A `ResourceVariable`, or a read `Tensor` whose op's first input
        is the variable handle.
      slice_spec: Slice specification string for partitioned variables, or
        "" for an unpartitioned variable.
      name: The checkpoint key under which `var` is saved.

    Raises:
      ValueError: If `var` is neither a resource variable nor a read op.
    """
    # Remember where the variable lives so restore() can copy the restored
    # value back onto the right device.
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      # `var` is the output of a read op; the op's first input is the handle.
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif isinstance(var, resource_variable_ops.ResourceVariable):

      def _read_variable_closure(v):
        # Deferred read: the tensor to save is produced lazily at save time
        # by calling this closure, rather than captured eagerly here.
        def f():
          with ops.device(v.device):
            x = v.read_value()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)
        return f

      self.handle_op = var.handle
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the restored value back into the variable.

    Args:
      restored_tensors: Single-element list holding the restored value.
      restored_shapes: Shapes to reshape the restored tensors to, or None
        to keep the checkpointed shapes.

    Returns:
      The assignment op.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      restored_tensor = array_ops.identity(restored_tensor)
    return resource_variable_ops.shape_safe_assign_variable_handle(
        self.handle_op, self._var_shape, restored_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor produced by a variable-related op."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
|
||||||
|
|
||||||
|
|
||||||
|
def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "checkpointable operations. Name is not a string: %s" % name)
  if isinstance(op, saveable_object.SaveableObject):
    # Already a SaveableObject; pass it through unchanged.
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices: all slices must belong to the same partitioned
    # variable, and each slice yields its own SaveableObject under the
    # shared name.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % variable)
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in ["Variable", "VariableV2",
                              "AutoReloadVariable"]:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
      op, variables.Variable):
    # A checkpointable object: recurse into whatever saveables it declares.
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == checkpointable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # NOTE: `op` is deliberately rebound here; the .items() snapshot above
      # is already materialized, so later iterations are unaffected.
      op = (factory(full_name) if callable(factory) else factory)
      for op in saveable_objects_for_op(op, op.name):
        yield op
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.ResourceVariable):
      # pylint: disable=protected-access
      if op._in_graph_mode:
        variable = op._graph_element
      else:
        variable = op
      # pylint: enable=protected-access
      yield ResourceVariableSaveable(variable, "", name)
    else:
      # Only ResourceVariables can be saved under eager execution; check in
      # init_scope so the surrounding (outermost) execution mode is tested.
      with ops.init_scope():
        if context.executing_eagerly():
          raise ValueError("Can only save/restore ResourceVariables when "
                           "executing eagerly, got type: %s." % type(op))

      variable = ops.internal_convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        variable)
      if variable.op.type in ["Variable", "VariableV2",
                              "AutoReloadVariable"]:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(
            variable, "", name)
|
||||||
|
|
||||||
|
|
||||||
|
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Args:
    op_list: A list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name. Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    if isinstance(var, saveable_object.SaveableObject):
      # SaveableObjects are keyed by their own name.
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # A slice of a partitioned variable: group all slices under the full
      # variable name.
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif (isinstance(var, checkpointable.CheckpointableBase)
          and not isinstance(var, variables.Variable)):
      # A checkpointable object: recursively collect whatever saveables it
      # declares and merge them into the result dict.
      checkpointable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      names_to_saveables.update(
          op_list_to_dict(checkpointable_saveables))
    else:
      if context.executing_eagerly():
        if not isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        # Under eager execution variables are deduplicated by shared name;
        # setdefault returns the previously registered variable, if any.
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.ResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.internal_convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        # Key resource-variable reads by the handle op's name so that the
        # checkpoint name matches the variable, not the read op.
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var

  # pylint: enable=protected-access
  return names_to_saveables
|
||||||
|
|
||||||
|
|
||||||
|
def _add_saveable(saveables, seen_ops, saveable):
|
||||||
|
"""Adds the saveable to the saveables list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saveables: List to append the SaveableObject to.
|
||||||
|
seen_ops: Set of the ops of the saveables already processed. Used to
|
||||||
|
check that each saveable is only saved once.
|
||||||
|
saveable: The saveable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the saveable has already been processed.
|
||||||
|
"""
|
||||||
|
if saveable.op in seen_ops:
|
||||||
|
raise ValueError("The same saveable will be restored with two names: %s" %
|
||||||
|
saveable.name)
|
||||||
|
saveables.append(saveable)
|
||||||
|
seen_ops.add(saveable.op)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
      v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a checkpointable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = set()
  # Sort on names only; comparing the ops themselves is not supported.
  sorted_items = sorted(names_to_saveables.items(), key=lambda item: item[0])
  for name, op in sorted_items:
    for saveable in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, saveable)
  return saveables
|
@ -28,7 +28,7 @@ from tensorflow.python.ops import variables as variables_lib
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_ops
|
from tensorflow.python.training import checkpoint_ops
|
||||||
from tensorflow.python.training import checkpoint_utils
|
from tensorflow.python.training import checkpoint_utils
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ def _infer_var_name(var):
|
|||||||
Returns:
|
Returns:
|
||||||
Name of the `var`
|
Name of the `var`
|
||||||
"""
|
"""
|
||||||
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
|
name_to_var_dict = saveable_object_util.op_list_to_dict(var)
|
||||||
if len(name_to_var_dict) > 1:
|
if len(name_to_var_dict) > 1:
|
||||||
raise TypeError("`var` = %s passed as arg violates the constraints. "
|
raise TypeError("`var` = %s passed as arg violates the constraints. "
|
||||||
"name_to_var_dict = %s" % (var, name_to_var_dict))
|
"name_to_var_dict = %s" % (var, name_to_var_dict))
|
||||||
|
Loading…
Reference in New Issue
Block a user