Add a functional saver, use it for object-based checkpointing
Pulls out of saver.py some utilities which are necessary to actually use the functional saver. The functional saver takes only SaveableObjects, so these utilities take a list of whatever users pass in and convert its elements to SaveableObjects. Also makes one other code move for object-based checkpointing, to avoid circular imports. Applications which need a SaverDef still use the old Saver; serialization to SaverDef will be added to the functional saver in a followup. Does not actually wrap the new Saver's methods in @tf.function yet, since there are memory issues which need to be fixed first. PiperOrigin-RevId: 224561069
This commit is contained in:
parent
4f543e588a
commit
66ca3cd10d
@ -30,6 +30,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
|
||||
LOCAL_VARIABLE_NAME = 'local_center_variable'
|
||||
GLOBAL_VARIABLE_NAME = 'global_center_variable'
|
||||
@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
|
||||
if var_list is None:
|
||||
var_list = variables.trainable_variables()
|
||||
if not isinstance(var_list, dict):
|
||||
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
|
||||
var_list = saveable_object_util.op_list_to_dict(var_list)
|
||||
|
||||
swapped_var_list = {}
|
||||
for key, var in var_list.items():
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import moving_averages
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
|
||||
|
||||
class MovingAverageOptimizer(optimizer.Optimizer):
|
||||
@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer):
|
||||
if var_list is None:
|
||||
var_list = variables.global_variables()
|
||||
if not isinstance(var_list, dict):
|
||||
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
|
||||
var_list = saveable_object_util.op_list_to_dict(var_list)
|
||||
|
||||
v_name_to_tensor = {}
|
||||
for k, tensor_or_list in six.iteritems(var_list):
|
||||
|
@ -3515,13 +3515,13 @@ py_library(
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"training/checkpointable/**/*.py",
|
||||
"training/saving/**/*.py",
|
||||
# The following targets have their own build rules (same name as the
|
||||
# file):
|
||||
"training/basic_session_run_hooks.py",
|
||||
"training/checkpoint_management.py",
|
||||
"training/distribute.py",
|
||||
"training/distribution_strategy_context.py",
|
||||
"training/saveable_object.py",
|
||||
"training/saver.py",
|
||||
"training/session_run_hook.py",
|
||||
"training/training_util.py",
|
||||
@ -3596,12 +3596,6 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saveable_object",
|
||||
srcs = ["training/saveable_object.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "checkpoint_management",
|
||||
srcs = ["training/checkpoint_management.py"],
|
||||
@ -3655,7 +3649,6 @@ py_library(
|
||||
":platform",
|
||||
":pywrap_tensorflow",
|
||||
":resource_variable_ops",
|
||||
":saveable_object",
|
||||
":session",
|
||||
":state_ops",
|
||||
":string_ops",
|
||||
@ -3665,6 +3658,8 @@ py_library(
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"//tensorflow/python/training/saving:saveable_object",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -30,7 +30,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -311,10 +311,10 @@ def _set_checkpoint_initializer(variable,
|
||||
restore_op = io_ops.restore_v2(
|
||||
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||
|
||||
names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
|
||||
names_to_saveables = saveable_object_util.op_list_to_dict([variable])
|
||||
saveable_objects = []
|
||||
for name, op in names_to_saveables.items():
|
||||
for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
|
||||
for s in saveable_object_util.saveable_objects_for_op(op, name):
|
||||
saveable_objects.append(s)
|
||||
|
||||
assert len(saveable_objects) == 1 # Should be only one variable.
|
||||
|
@ -25,9 +25,9 @@ py_library(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops_gen",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:saveable_object",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/training/saving:saveable_object",
|
||||
],
|
||||
)
|
||||
|
||||
@ -114,7 +114,6 @@ py_library(
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:io_ops_gen",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:saveable_object",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
@ -123,6 +122,9 @@ py_library(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/training/saving:functional_saver",
|
||||
"//tensorflow/python/training/saving:saveable_object",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -25,7 +25,6 @@ import weakref
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -34,7 +33,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_io_ops as io_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.util import tf_decorator
|
||||
@ -374,41 +373,10 @@ class _CheckpointPosition(object):
|
||||
eagerly.
|
||||
"""
|
||||
(restore_ops,
|
||||
named_saveables,
|
||||
tensor_saveables,
|
||||
python_saveables) = self._gather_ops_or_named_saveables()
|
||||
|
||||
# Eagerly run restorations for Python state.
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(
|
||||
self._checkpoint.save_path_string)
|
||||
for saveable in python_saveables:
|
||||
spec_names = [spec.name for spec in saveable.specs]
|
||||
saveable.python_restore(
|
||||
[reader.get_tensor(name) for name in spec_names])
|
||||
|
||||
# If we have new SaveableObjects, extract and cache restore ops.
|
||||
if named_saveables:
|
||||
validated_saveables = (
|
||||
self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access
|
||||
validated_names = set(saveable.name for saveable in validated_saveables)
|
||||
if set(named_saveables.keys()) != validated_names:
|
||||
raise AssertionError(
|
||||
("Saveable keys changed when validating. Got back %s, was "
|
||||
"expecting %s") % (named_saveables.keys(), validated_names))
|
||||
all_tensors = self._checkpoint.builder.bulk_restore(
|
||||
filename_tensor=self._checkpoint.save_path_tensor,
|
||||
saveables=validated_saveables, preferred_shard=-1,
|
||||
restore_sequentially=False)
|
||||
saveable_index = 0
|
||||
for saveable in validated_saveables:
|
||||
num_specs = len(saveable.specs)
|
||||
saveable_tensors = all_tensors[
|
||||
saveable_index:saveable_index + num_specs]
|
||||
saveable_index += num_specs
|
||||
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
|
||||
if not context.executing_eagerly():
|
||||
assert saveable.name not in self._checkpoint.restore_ops_by_name
|
||||
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
|
||||
restore_ops.append(restore_op)
|
||||
restore_ops.extend(self._checkpoint.restore_saveables(
|
||||
tensor_saveables, python_saveables))
|
||||
return restore_ops
|
||||
|
||||
@property
|
||||
|
@ -40,11 +40,14 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import optimizer as optimizer_lib
|
||||
from tensorflow.python.training import saveable_object as saveable_object_lib
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import saver as v1_saver_lib
|
||||
from tensorflow.python.training.checkpointable import base
|
||||
from tensorflow.python.training.checkpointable import data_structures
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
from tensorflow.python.training.saving import functional_saver
|
||||
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -89,7 +92,6 @@ class _CheckpointRestoreCoordinator(object):
|
||||
referenced every restore (e.g. for Python state); otherwise they would
|
||||
create their own ops every restore.
|
||||
"""
|
||||
self.builder = saver_lib.BulkSaverBuilder()
|
||||
self.object_graph_proto = object_graph_proto
|
||||
self.restore_uid = ops.uid()
|
||||
# Maps from objects to lists of attributes which were in the checkpoint but
|
||||
@ -144,6 +146,57 @@ class _CheckpointRestoreCoordinator(object):
|
||||
if self.new_restore_ops_callback:
|
||||
self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
|
||||
|
||||
def restore_saveables(self, tensor_saveables, python_saveables):
|
||||
"""Run or build restore operations for SaveableObjects.
|
||||
|
||||
Args:
|
||||
tensor_saveables: `SaveableObject`s which correspond to Tensors.
|
||||
python_saveables: `PythonStateSaveable`s which correspond to Python
|
||||
values.
|
||||
|
||||
Returns:
|
||||
When graph building, a list of restore operations, either cached or newly
|
||||
created, to restore `tensor_saveables`.
|
||||
"""
|
||||
restore_ops = []
|
||||
# Eagerly run restorations for Python state.
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(
|
||||
self.save_path_string)
|
||||
for saveable in python_saveables:
|
||||
spec_names = [spec.name for spec in saveable.specs]
|
||||
saveable.python_restore(
|
||||
[reader.get_tensor(name) for name in spec_names])
|
||||
|
||||
# If we have new SaveableObjects, extract and cache restore ops.
|
||||
if tensor_saveables:
|
||||
validated_saveables = saveable_object_util.validate_and_slice_inputs(
|
||||
tensor_saveables)
|
||||
validated_names = set(saveable.name for saveable in validated_saveables)
|
||||
if set(tensor_saveables.keys()) != validated_names:
|
||||
raise AssertionError(
|
||||
("Saveable keys changed when validating. Got back %s, was "
|
||||
"expecting %s") % (tensor_saveables.keys(), validated_names))
|
||||
for saveable in validated_saveables:
|
||||
if saveable.device:
|
||||
device = saveable_object_util.set_cpu0(saveable.device)
|
||||
else:
|
||||
device = None
|
||||
with ops.device(device):
|
||||
tensors = []
|
||||
for spec in saveable.specs:
|
||||
tensors.append(
|
||||
io_ops.restore_v2(
|
||||
self.save_path_tensor,
|
||||
[spec.name],
|
||||
[spec.slice_spec],
|
||||
[spec.dtype])[0])
|
||||
restore_op = saveable.restore(tensors, restored_shapes=None)
|
||||
if not context.executing_eagerly():
|
||||
assert saveable.name not in self.restore_ops_by_name
|
||||
self.restore_ops_by_name[saveable.name] = restore_op
|
||||
restore_ops.append(restore_op)
|
||||
return restore_ops
|
||||
|
||||
|
||||
class _NameBasedRestoreCoordinator(object):
|
||||
"""Keeps the status of a name-based checkpoint restore."""
|
||||
@ -183,11 +236,11 @@ class _NameBasedRestoreCoordinator(object):
|
||||
continue
|
||||
else:
|
||||
saveable = saveable_factory
|
||||
names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict(
|
||||
names_to_saveables = saveable_object_util.op_list_to_dict(
|
||||
[saveable],
|
||||
convert_variable_to_tensor=False)
|
||||
for name, op in names_to_saveables.items():
|
||||
for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
|
||||
for saveable_object in saveable_object_util.saveable_objects_for_op(
|
||||
op=op, name=name):
|
||||
yield saveable_object
|
||||
|
||||
@ -606,10 +659,10 @@ def _add_attributes_to_object_graph(
|
||||
# Figure out the name-based Saver's name for this variable. If it's
|
||||
# already a SaveableObject we'd just get the checkpoint key back, so
|
||||
# we leave full_name blank.
|
||||
saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
|
||||
saver_dict = saveable_object_util.op_list_to_dict(
|
||||
[maybe_saveable], convert_variable_to_tensor=False)
|
||||
full_name, = saver_dict.keys()
|
||||
saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
|
||||
saveables = tuple(saveable_object_util.saveable_objects_for_op(
|
||||
op=maybe_saveable, name=attribute.checkpoint_key))
|
||||
for saveable in saveables:
|
||||
saveable.full_name = full_name
|
||||
@ -1226,7 +1279,7 @@ class NameBasedSaverStatus(_LoadStatus):
|
||||
session = ops.get_default_session()
|
||||
with ops.device("/cpu:0"):
|
||||
saveables = self._gather_saveable_objects()
|
||||
saver_lib.Saver(saveables).restore(
|
||||
v1_saver_lib.Saver(saveables).restore(
|
||||
sess=session, save_path=self._checkpoint.save_path)
|
||||
|
||||
def initialize_or_restore(self, session=None):
|
||||
@ -1251,18 +1304,6 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
|
||||
fetches=fetches, feed_dict=feed_dict, **kwargs)
|
||||
|
||||
|
||||
def _copy_saver_with_new_var_list(old_saver, new_var_list):
|
||||
"""Copy a `tf.train.Saver`'s state to a new Saver with different variables."""
|
||||
new_saver = saver_lib.Saver(var_list=new_var_list, max_to_keep=None)
|
||||
# TODO(allenl): Move copying functionality to Saver?
|
||||
# pylint: disable=protected-access
|
||||
new_saver._last_checkpoints = old_saver._last_checkpoints
|
||||
new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted
|
||||
new_saver._next_checkpoint_time = old_saver._next_checkpoint_time
|
||||
# pylint: enable=protected-access
|
||||
return new_saver
|
||||
|
||||
|
||||
class CheckpointableSaver(object):
|
||||
"""Saves and restores a `Checkpointable` object and its dependencies.
|
||||
|
||||
@ -1301,7 +1342,8 @@ class CheckpointableSaver(object):
|
||||
# Op caching for save
|
||||
self._object_graph_feed_tensor = None
|
||||
self._last_save_object_graph = None
|
||||
self._last_save_saver = None
|
||||
self._file_prefix_feed_tensor = None
|
||||
self._cached_save_operation = None
|
||||
|
||||
# Op caching for restore, shared between _CheckpointRestoreCoordinators
|
||||
self._restore_op_cache = {}
|
||||
@ -1368,11 +1410,14 @@ class CheckpointableSaver(object):
|
||||
base.NoRestoreSaveable(
|
||||
tensor=object_graph_tensor,
|
||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
# TODO(allenl, haoliang): Swap in a function-based saver here.
|
||||
return saver_lib.Saver(
|
||||
# TODO(allenl): Swap in a function-based saver here once it can serialize
|
||||
# to a SaverDef.
|
||||
return v1_saver_lib.Saver(
|
||||
var_list=named_saveable_objects, max_to_keep=None)
|
||||
|
||||
def _prepare_save(self,
|
||||
def _save_cached_when_graph_building(
|
||||
self,
|
||||
file_prefix,
|
||||
object_graph_tensor=None,
|
||||
saveable_object_cache=None):
|
||||
"""Create or retrieve save ops.
|
||||
@ -1383,15 +1428,17 @@ class CheckpointableSaver(object):
|
||||
unnecessarily re-creating save ops.
|
||||
|
||||
Args:
|
||||
file_prefix: The prefix for saved checkpoint files.
|
||||
object_graph_tensor: A `Tensor` to which the current object graph will be
|
||||
fed.
|
||||
saveable_object_cache: A dictionary; if specified, used to cache
|
||||
`SaveableObject`s.
|
||||
|
||||
Returns:
|
||||
A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
|
||||
to feed when running save ops. The feed dict contains the current object
|
||||
graph and any Python state to be saved in the checkpoint.
|
||||
A two-element tuple with a filename tensor and a feed_dict of tensors to
|
||||
feed when running it (if graph building). The feed dict contains the
|
||||
current object graph and any Python state to be saved in the
|
||||
checkpoint. When executing eagerly only the first element is meaningful.
|
||||
"""
|
||||
(named_saveable_objects, graph_proto,
|
||||
feed_additions) = self._gather_saveables(
|
||||
@ -1403,15 +1450,11 @@ class CheckpointableSaver(object):
|
||||
# constructors. That means the Saver needs to be copied with a new
|
||||
# var_list.
|
||||
or context.executing_eagerly()):
|
||||
if self._last_save_object_graph is not None:
|
||||
self._last_save_saver = _copy_saver_with_new_var_list(
|
||||
old_saver=self._last_save_saver,
|
||||
new_var_list=named_saveable_objects)
|
||||
else:
|
||||
self._last_save_saver = saver_lib.Saver(
|
||||
var_list=named_saveable_objects, max_to_keep=None)
|
||||
saver = functional_saver.Saver(named_saveable_objects)
|
||||
with ops.device("/cpu:0"):
|
||||
self._cached_save_operation = saver.save(file_prefix)
|
||||
self._last_save_object_graph = graph_proto
|
||||
return self._last_save_saver, feed_additions
|
||||
return self._cached_save_operation, feed_additions
|
||||
|
||||
def save(self, file_prefix, checkpoint_number=None, session=None):
|
||||
"""Save a training checkpoint.
|
||||
@ -1435,36 +1478,42 @@ class CheckpointableSaver(object):
|
||||
Returns:
|
||||
The full path to the checkpoint.
|
||||
"""
|
||||
feed_additions = {}
|
||||
feed_dict = {}
|
||||
graph_building = not context.executing_eagerly()
|
||||
if checkpoint_number:
|
||||
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
|
||||
if graph_building:
|
||||
if self._object_graph_feed_tensor is None:
|
||||
with ops.device("/cpu:0"):
|
||||
self._object_graph_feed_tensor = constant_op.constant(
|
||||
"", dtype=dtypes.string)
|
||||
self._file_prefix_feed_tensor = constant_op.constant(
|
||||
"", dtype=dtypes.string)
|
||||
object_graph_tensor = self._object_graph_feed_tensor
|
||||
file_prefix_tensor = self._file_prefix_feed_tensor
|
||||
feed_dict[file_prefix_tensor] = file_prefix
|
||||
else:
|
||||
with ops.device("/cpu:0"):
|
||||
file_prefix_tensor = constant_op.constant(
|
||||
file_prefix, dtype=dtypes.string)
|
||||
object_graph_tensor = None
|
||||
|
||||
saver, new_feed_additions = self._prepare_save(
|
||||
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
||||
save_path, new_feed_additions = self._save_cached_when_graph_building(
|
||||
file_prefix=file_prefix_tensor,
|
||||
object_graph_tensor=object_graph_tensor,
|
||||
saveable_object_cache=self._saveable_object_cache)
|
||||
if new_feed_additions:
|
||||
feed_additions.update(new_feed_additions)
|
||||
feed_dict.update(new_feed_additions)
|
||||
if not graph_building:
|
||||
session = None
|
||||
elif session is None:
|
||||
session = ops.get_default_session()
|
||||
|
||||
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
||||
with ops.device("/cpu:0"):
|
||||
save_path = saver.save(
|
||||
sess=_SessionWithFeedDictAdditions(
|
||||
session=session, feed_additions=feed_additions),
|
||||
save_path=file_prefix,
|
||||
write_meta_graph=False,
|
||||
write_state=False,
|
||||
global_step=checkpoint_number)
|
||||
if session:
|
||||
save_path = session.run(save_path, feed_dict=feed_dict)
|
||||
else:
|
||||
save_path = save_path.numpy()
|
||||
return save_path
|
||||
|
||||
def restore(self, save_path):
|
||||
@ -1753,9 +1802,9 @@ class Checkpoint(tracking.Checkpointable):
|
||||
Returns:
|
||||
The full path to the checkpoint (i.e. `file_prefix`).
|
||||
"""
|
||||
return self._saver.save(
|
||||
return compat.as_str(self._saver.save(
|
||||
file_prefix=file_prefix,
|
||||
session=session)
|
||||
session=session))
|
||||
|
||||
@property
|
||||
def save_counter(self):
|
||||
|
@ -14,7 +14,11 @@
|
||||
# ==============================================================================
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
"""Save and restore variables."""
|
||||
"""Save and restore variables.
|
||||
|
||||
Symbols in this file are deprecated. See replacements in
|
||||
tensorflow/python/training/checkpointable and tensorflow/python/training/saving.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -25,7 +29,6 @@ import time
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
@ -42,16 +45,15 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_io_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saveable_object
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -67,31 +69,6 @@ get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
|
||||
remove_checkpoint = checkpoint_management.remove_checkpoint
|
||||
|
||||
|
||||
# Op names which identify variable reads which should be saved.
|
||||
_VARIABLE_OPS = set(["Variable",
|
||||
"VariableV2",
|
||||
"AutoReloadVariable",
|
||||
"VarHandleOp",
|
||||
"ReadVariableOp"])
|
||||
|
||||
|
||||
def _set_cpu0(device_string):
|
||||
"""Creates a new device string based on `device_string` but using /CPU:0.
|
||||
|
||||
If the device is already on /CPU:0, this is a no-op.
|
||||
|
||||
Args:
|
||||
device_string: A device string.
|
||||
|
||||
Returns:
|
||||
A device string.
|
||||
"""
|
||||
parsed_device = pydev.DeviceSpec.from_string(device_string)
|
||||
parsed_device.device_type = "CPU"
|
||||
parsed_device.device_index = 0
|
||||
return parsed_device.to_string()
|
||||
|
||||
|
||||
class BaseSaverBuilder(object):
|
||||
"""Base class for Savers.
|
||||
|
||||
@ -101,64 +78,9 @@ class BaseSaverBuilder(object):
|
||||
SaveSpec = saveable_object.SaveSpec
|
||||
SaveableObject = saveable_object.SaveableObject
|
||||
|
||||
class VariableSaveable(SaveableObject):
|
||||
"""SaveableObject implementation that handles Variables."""
|
||||
|
||||
def __init__(self, var, slice_spec, name):
|
||||
spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype)
|
||||
super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
restored_tensor = restored_tensors[0]
|
||||
if restored_shapes is not None:
|
||||
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
|
||||
return state_ops.assign(
|
||||
self.op,
|
||||
restored_tensor,
|
||||
validate_shape=restored_shapes is None and
|
||||
self.op.get_shape().is_fully_defined())
|
||||
|
||||
class ResourceVariableSaveable(SaveableObject):
|
||||
"""SaveableObject implementation that handles ResourceVariables."""
|
||||
|
||||
def __init__(self, var, slice_spec, name):
|
||||
self._var_device = var.device
|
||||
self._var_shape = var.shape
|
||||
if isinstance(var, ops.Tensor):
|
||||
self.handle_op = var.op.inputs[0]
|
||||
tensor = var
|
||||
elif isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
|
||||
def _read_variable_closure(v):
|
||||
def f():
|
||||
with ops.device(v.device):
|
||||
x = v.read_value()
|
||||
# To allow variables placed on non-CPU devices to be checkpointed,
|
||||
# we copy them to CPU on the same machine first.
|
||||
with ops.device("/device:CPU:0"):
|
||||
return array_ops.identity(x)
|
||||
return f
|
||||
|
||||
self.handle_op = var.handle
|
||||
tensor = _read_variable_closure(var)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Saveable is neither a resource variable nor a read operation."
|
||||
" Got: %s" % repr(var))
|
||||
spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name,
|
||||
dtype=var.dtype)
|
||||
super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
|
||||
var, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
restored_tensor = restored_tensors[0]
|
||||
if restored_shapes is not None:
|
||||
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
|
||||
# Copy the restored tensor to the variable's device.
|
||||
with ops.device(self._var_device):
|
||||
restored_tensor = array_ops.identity(restored_tensor)
|
||||
return resource_variable_ops.shape_safe_assign_variable_handle(
|
||||
self.handle_op, self._var_shape, restored_tensor)
|
||||
# Aliases for code which was moved but still has lots of users.
|
||||
VariableSaveable = saveable_object_util.ReferenceVariableSaveable
|
||||
ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
|
||||
|
||||
def __init__(self, write_version=saver_pb2.SaverDef.V2):
|
||||
self._write_version = write_version
|
||||
@ -224,7 +146,11 @@ class BaseSaverBuilder(object):
|
||||
del restore_sequentially
|
||||
all_tensors = []
|
||||
for saveable in saveables:
|
||||
with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
|
||||
if saveable.device:
|
||||
device = saveable_object_util.set_cpu0(saveable.device)
|
||||
else:
|
||||
device = None
|
||||
with ops.device(device):
|
||||
all_tensors.extend(
|
||||
self.restore_op(filename_tensor, saveable, preferred_shard))
|
||||
return all_tensors
|
||||
@ -336,7 +262,7 @@ class BaseSaverBuilder(object):
|
||||
last_device = None
|
||||
for shard, (device, saveables) in enumerate(per_device):
|
||||
last_device = device
|
||||
with ops.device(_set_cpu0(device)):
|
||||
with ops.device(saveable_object_util.set_cpu0(device)):
|
||||
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
|
||||
num_shards_tensor)
|
||||
sharded_prefixes.append(sharded_filename)
|
||||
@ -344,7 +270,7 @@ class BaseSaverBuilder(object):
|
||||
|
||||
with ops.control_dependencies([x.op for x in sharded_saves]):
|
||||
# Co-locates the merge step with the last device.
|
||||
with ops.device(_set_cpu0(last_device)):
|
||||
with ops.device(saveable_object_util.set_cpu0(last_device)):
|
||||
# V2 format write path consists of a metadata merge step. Once merged,
|
||||
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
||||
merge_step = gen_io_ops.merge_v2_checkpoints(
|
||||
@ -459,10 +385,6 @@ class BaseSaverBuilder(object):
|
||||
name="restore_shard"))
|
||||
return control_flow_ops.group(*sharded_restores, name="restore_all")
|
||||
|
||||
@staticmethod
|
||||
def _IsVariable(v):
|
||||
return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
|
||||
|
||||
def _GroupByDevices(self, saveables):
|
||||
"""Group Variable tensor slices per device.
|
||||
|
||||
@ -490,220 +412,6 @@ class BaseSaverBuilder(object):
|
||||
per_device[canonical_device.pop()].append(saveable)
|
||||
return sorted(per_device.items(), key=lambda t: t[0])
|
||||
|
||||
@staticmethod
|
||||
def OpListToDict(op_list, convert_variable_to_tensor=True):
|
||||
"""Create a dictionary of names to operation lists.
|
||||
|
||||
Args:
|
||||
op_list: A list, tuple, or set of Variables or SaveableObjects.
|
||||
convert_variable_to_tensor: Whether or not to convert single Variables
|
||||
with no slice info into Tensors.
|
||||
|
||||
Returns:
|
||||
A dictionary of names to the operations that must be saved under
|
||||
that name. Variables with save_slice_info are grouped together under the
|
||||
same key in no particular order.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of op_list or its elements is not supported.
|
||||
ValueError: If at least two saveables share the same name.
|
||||
"""
|
||||
if not isinstance(op_list, (list, tuple, set)):
|
||||
raise TypeError("Variables to save should be passed in a dict or a "
|
||||
"list: %s" % op_list)
|
||||
# When ResourceVariables are converted to Tensors, read ops are added to the
|
||||
# graph. Sorting the op_list ensures that the resulting graph is always
|
||||
# constructed in a deterministic way:
|
||||
op_list = sorted(op_list, key=lambda x: x.name)
|
||||
names_to_saveables = {}
|
||||
# pylint: disable=protected-access
|
||||
for var in op_list:
|
||||
if isinstance(var, BaseSaverBuilder.SaveableObject):
|
||||
names_to_saveables[var.name] = var
|
||||
elif isinstance(var, variables.PartitionedVariable):
|
||||
if var.name in names_to_saveables:
|
||||
raise ValueError("At least two variables have the same name: %s" %
|
||||
var.name)
|
||||
names_to_saveables[var.name] = var
|
||||
elif isinstance(var, variables.Variable) and var._save_slice_info:
|
||||
name = var._save_slice_info.full_name
|
||||
if name in names_to_saveables:
|
||||
if not isinstance(names_to_saveables[name], list):
|
||||
raise ValueError("Mixing slices and non-slices with the same name: "
|
||||
"%s" % name)
|
||||
names_to_saveables[name].append(var)
|
||||
else:
|
||||
names_to_saveables[name] = [var]
|
||||
elif (isinstance(var, checkpointable.CheckpointableBase)
|
||||
and not isinstance(var, variables.Variable)):
|
||||
checkpointable_saveables = [
|
||||
(factory() if callable(factory) else factory)
|
||||
for factory in var._gather_saveables_for_checkpoint().values()]
|
||||
names_to_saveables.update(
|
||||
BaseSaverBuilder.OpListToDict(checkpointable_saveables))
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
raise ValueError(
|
||||
"Can only save/restore ResourceVariables when eager execution "
|
||||
"is enabled, type: %s." % type(var))
|
||||
set_var = names_to_saveables.setdefault(var._shared_name, var)
|
||||
if set_var is not var:
|
||||
raise ValueError(
|
||||
("Two different ResourceVariable objects with the same "
|
||||
"shared_name '%s' were passed to the Saver. This likely means "
|
||||
"that they were created in different Graphs or isolation "
|
||||
"contexts, and may not be checkpointed together.") %
|
||||
(var._shared_name,))
|
||||
else:
|
||||
if convert_variable_to_tensor:
|
||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
var = var._graph_element # pylint: disable=protected-access
|
||||
else:
|
||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
||||
if not BaseSaverBuilder._IsVariable(var):
|
||||
raise TypeError("Variable to save is not a Variable: %s" % var)
|
||||
if var.op.type == "ReadVariableOp":
|
||||
name = var.op.inputs[0].op.name
|
||||
else:
|
||||
name = var.op.name
|
||||
if name in names_to_saveables:
|
||||
raise ValueError("At least two variables have the same name: %s" %
|
||||
name)
|
||||
names_to_saveables[name] = var
|
||||
|
||||
# pylint: enable=protected-access
|
||||
return names_to_saveables
|
||||
|
||||
@staticmethod
def SaveableObjectsForOp(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject. A list, tuple, or `PartitionedVariable` is treated as
      a set of slices which must all belong to the same tensor.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a tensor
      which `BaseSaverBuilder._IsVariable` rejects.
    ValueError: For operations with no known conversion to SaveableObject:
      slice lists whose members are not variables, carry no slice info, or
      disagree on the full name, and anything other than a
      `ResourceVariable` when executing eagerly.
  """
  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "checkpointable operations. Name is not a string: %s" % name)
  # Op types saved through VariableSaveable; every other variable op type is
  # saved through ResourceVariableSaveable.
  reference_op_types = ("Variable", "VariableV2", "AutoReloadVariable")
  if isinstance(op, BaseSaverBuilder.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % variable)
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        # The first slice fixes the full name all later slices must share.
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in reference_op_types:
        yield BaseSaverBuilder.VariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield BaseSaverBuilder.ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == checkpointable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      gathered = factory(full_name) if callable(factory) else factory
      # Recurse so that whatever the factory produced is coerced as well.
      for saveable in BaseSaverBuilder.SaveableObjectsForOp(
          gathered, gathered.name):
        yield saveable
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if context.executing_eagerly():
      if not isinstance(op, resource_variable_ops.ResourceVariable):
        raise ValueError(
            "Can only save/restore ResourceVariables when eager execution "
            "is enabled, type: %s." % type(op))
      yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
    else:
      if isinstance(op, resource_variable_ops.ResourceVariable):
        variable = op._graph_element  # pylint: disable=protected-access
      else:
        variable = ops.internal_convert_to_tensor(op, as_ref=True)
      if not BaseSaverBuilder._IsVariable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        variable)
      if variable.op.type in reference_op_types:
        yield BaseSaverBuilder.VariableSaveable(variable, "", name)
      else:
        yield BaseSaverBuilder.ResourceVariableSaveable(
            variable, "", name)
|
||||
|
||||
def _ValidateAndSliceInputs(self, names_to_saveables):
  """Converts the Saver's inputs into a flat list of `SaveableObject`s.

  Args:
    names_to_saveables: Either a dict mapping names to the objects to save
      under them, or any other value, which is first passed through
      `BaseSaverBuilder.OpListToDict` to build such a dict.

  Returns:
    A list of BaseSaverBuilder.SaveableObject objects, produced by visiting
    the names in sorted order.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a checkpointable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)

  collected = []
  ops_already_added = set()
  # Walk the keys in sorted order; only names are compared, never the ops.
  for saveable_name in sorted(names_to_saveables):
    target = names_to_saveables[saveable_name]
    for converted in self.SaveableObjectsForOp(target, saveable_name):
      self._AddSaveable(collected, ops_already_added, converted)
  return collected
|
||||
|
||||
def _AddSaveable(self, saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables`, rejecting repeated ops.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed. It is
      updated in place with the op of `saveable`.
    saveable: The saveable to record.

  Raises:
    ValueError: If the op of `saveable` is already in `seen_ops`.
  """
  saveable_op = saveable.op
  if saveable_op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(saveable_op)
|
||||
|
||||
def build(self,
|
||||
names_to_saveables,
|
||||
reshape=False,
|
||||
@ -775,7 +483,8 @@ class BaseSaverBuilder(object):
|
||||
raise ValueError("save and restore operations need to be built together "
|
||||
" when eager execution is not enabled.")
|
||||
|
||||
saveables = self._ValidateAndSliceInputs(names_to_saveables)
|
||||
saveables = saveable_object_util.validate_and_slice_inputs(
|
||||
names_to_saveables)
|
||||
if max_to_keep is None:
|
||||
max_to_keep = 0
|
||||
|
||||
@ -1910,7 +1619,7 @@ def saver_from_object_based_checkpoint(
|
||||
if builder is None:
|
||||
builder = BulkSaverBuilder()
|
||||
|
||||
saveables = builder._ValidateAndSliceInputs(var_list) # pylint: disable=protected-access
|
||||
saveables = saveable_object_util.validate_and_slice_inputs(var_list)
|
||||
current_names = set()
|
||||
for saveable in saveables:
|
||||
for spec in saveable.specs:
|
||||
|
55
tensorflow/python/training/saving/BUILD
Normal file
55
tensorflow/python/training/saving/BUILD
Normal file
@ -0,0 +1,55 @@
|
||||
# Description:
|
||||
# Low-level utilities for reading and writing checkpoints.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "functional_saver",
|
||||
srcs = ["functional_saver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saveable_object",
|
||||
":saveable_object_util",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "functional_saver_test",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"functional_saver_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":functional_saver",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saveable_object",
|
||||
srcs = ["saveable_object.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saveable_object_util",
|
||||
srcs = ["saveable_object_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
101
tensorflow/python/training/saving/functional_saver.py
Normal file
101
tensorflow/python/training/saving/functional_saver.py
Normal file
@ -0,0 +1,101 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Saves and restores variables inside traced @tf.functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
|
||||
|
||||
class Saver(object):
  """Saves and restores the tensors described by a set of `SaveableObject`s.

  This is a low-level utility: tensors are written to and read from exactly
  the checkpoint keys named by each `SaveableObject`'s specs. Higher-level
  utilities for object-based checkpointing are built on top of it.
  """

  def __init__(self, saveable_objects):
    """Records the `SaveableObject`s this saver will save and restore.

    Args:
      saveable_objects: An iterable of `SaveableObject`s. It is copied into a
        list before being validated.

    Raises:
      ValueError: If any element is not a `SaveableObject`.
    """
    saveable_objects = list(saveable_objects)
    for candidate in saveable_objects:
      if isinstance(candidate, saveable_object.SaveableObject):
        continue
      raise ValueError(
          "Saver expected a list of SaveableObjects, got %s." % (candidate,))
    self._saveable_objects = saveable_objects

  # TODO(b/120569892): Use tf.function here
  def save(self, file_prefix):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.

    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the save ops.
    """
    # Flatten every spec of every saveable, preserving their order.
    all_specs = [spec
                 for saveable in self._saveable_objects
                 for spec in saveable.specs]
    tensor_names = [spec.name for spec in all_specs]
    tensors = [spec.tensor for spec in all_specs]
    tensor_slices = [spec.slice_spec for spec in all_specs]
    save_op = io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
    with ops.control_dependencies([save_op]):
      return array_ops.identity(file_prefix)

  # TODO(b/120569892): Use tf.function here
  def restore(self, file_prefix):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      An operation which restores the `Saver`'s `SaveableObject`s when run, or
      None if executing eagerly.
    """
    restore_ops = []
    for saveable in self._saveable_objects:
      # Read on CPU:0 of the saveable's device; with no device recorded, no
      # explicit placement is requested.
      device = (saveable_object_util.set_cpu0(saveable.device)
                if saveable.device else None)
      with ops.device(device):
        # One restore op per spec, each reading a single named tensor.
        restored = [
            io_ops.restore_v2(
                file_prefix,
                [spec.name],
                [spec.slice_spec],
                [spec.dtype])[0]
            for spec in saveable.specs]
        restore_ops.append(saveable.restore(restored, restored_shapes=None))
    return control_flow_ops.group(restore_ops)
|
50
tensorflow/python/training/saving/functional_saver_test.py
Normal file
50
tensorflow/python/training/saving/functional_saver_test.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Tests for the functional saver."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.training.saving import functional_saver
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
|
||||
|
||||
class SaverTest(test.TestCase):
  """Round-trips a resource variable through the functional `Saver`."""

  def test_resource_variable(self):
    """A saved value is restored into the same and into a new variable."""
    original = resource_variable_ops.ResourceVariable(2.)
    original_saveables = saveable_object_util.saveable_objects_for_op(
        original, "x")
    saver = functional_saver.Saver(original_saveables)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = saver.save(constant_op.constant(prefix))
    # Overwrite the value so that a successful restore is observable.
    original.assign(1.)
    saver.restore(save_path)
    self.assertEqual(2., self.evaluate(original))

    # A different variable saved under the same key "x" gets the same value.
    other = resource_variable_ops.ResourceVariable(3.)
    other_saveables = saveable_object_util.saveable_objects_for_op(other, "x")
    second_saver = functional_saver.Saver(other_saveables)
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(other))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
340
tensorflow/python/training/saving/saveable_object_util.py
Normal file
340
tensorflow/python/training/saving/saveable_object_util.py
Normal file
@ -0,0 +1,340 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for working with and creating SaveableObjects."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
|
||||
|
||||
# Op types whose output tensors are treated as variable reads that should be
# saved.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}
|
||||
|
||||
|
||||
def set_cpu0(device_string):
  """Returns `device_string` with its device replaced by CPU number 0.

  The string is parsed into a `DeviceSpec`, its device type is set to "CPU"
  and its device index to 0, and the result is serialized back to a string.
  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  cpu_spec = pydev.DeviceSpec.from_string(device_string)
  cpu_spec.device_type = "CPU"
  cpu_spec.device_index = 0
  return cpu_spec.to_string()
|
||||
|
||||
|
||||
class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in a single `SaveSpec` saved under `name`.

    Args:
      var: The variable (or its tensor) to save; its `dtype` is recorded in
        the spec.
      slice_spec: The slice specification passed to the `SaveSpec`.
      name: The name to save the value under.
    """
    var_spec = saveable_object.SaveSpec(
        var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [var_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped variable.

    Args:
      restored_tensors: A sequence whose first element is the restored value.
      restored_shapes: None, or a sequence whose first element is the shape
        the restored value is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign` on the wrapped op.
    """
    value = restored_tensors[0]
    reshaped = restored_shapes is not None
    if reshaped:
      value = array_ops.reshape(value, restored_shapes[0])
    # Shapes are validated only when no reshape was requested and the
    # variable's static shape is fully known.
    should_validate = (
        not reshaped and self.op.get_shape().is_fully_defined())
    return state_ops.assign(self.op, value, validate_shape=should_validate)
|
||||
|
||||
|
||||
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    """Builds the `SaveSpec` for a resource variable or a read of one.

    Args:
      var: Either an `ops.Tensor` (the resource handle is then taken from
        the first input of the op producing it) or a `ResourceVariable`.
      slice_spec: The slice specification passed to the `SaveSpec`.
      name: The name to save the value under.

    Raises:
      ValueError: If `var` is neither a Tensor nor a `ResourceVariable`.
    """
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif isinstance(var, resource_variable_ops.ResourceVariable):

      def read_to_cpu():
        """Reads the variable on its own device, then copies it to CPU."""
        with ops.device(var.device):
          value = var.read_value()
          # To allow variables placed on non-CPU devices to be checkpointed,
          # we copy them to CPU on the same machine first.
          with ops.device("/device:CPU:0"):
            return array_ops.identity(value)

      self.handle_op = var.handle
      # The spec receives the callable itself rather than its result.
      tensor = read_to_cpu
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    var_spec = saveable_object.SaveSpec(
        tensor, slice_spec, name, dtype=var.dtype)
    super(ResourceVariableSaveable, self).__init__(var, [var_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the variable's handle.

    Args:
      restored_tensors: A sequence whose first element is the restored value.
      restored_shapes: None, or a sequence whose first element is the shape
        the restored value is reshaped to before assignment.

    Returns:
      The result of `shape_safe_assign_variable_handle` for the handle.
    """
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    # Copy the restored tensor to the variable's device.
    with ops.device(self._var_device):
      value_on_device = array_ops.identity(value)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, value_on_device)
|
||||
|
||||
|
||||
def _tensor_comes_from_variable(v):
  """Returns True if `v` is a Tensor produced by an op in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
|
||||
|
||||
|
||||
def saveable_objects_for_op(op, name):
  """Coerce `op` into one or more `SaveableObject`s.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject. A list, tuple, or `PartitionedVariable` is treated as
      a set of slices which must all belong to the same tensor.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a tensor
      that does not come from a variable.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "checkpointable operations. Name is not a string: %s" % name)
  # Op types handled by ReferenceVariableSaveable; anything else goes through
  # ResourceVariableSaveable.
  reference_op_types = ("Variable", "VariableV2", "AutoReloadVariable")
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    # A set of slices.
    slices = list(op) if isinstance(op, variables.PartitionedVariable) else op
    shared_full_name = None
    # pylint: disable=protected-access
    for var_slice in slices:
      if not isinstance(var_slice, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % var_slice)
      if not var_slice._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % var_slice)
      slice_info = var_slice._save_slice_info
      if shared_full_name is None:
        # The first slice fixes the full name all later slices must share.
        shared_full_name = slice_info.full_name
      elif shared_full_name != slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (shared_full_name, slice_info.full_name))
      if var_slice.op.type in reference_op_types:
        saveable_cls = ReferenceVariableSaveable
      else:
        saveable_cls = ResourceVariableSaveable
      yield saveable_cls(var_slice, slice_info.spec, name)
    # pylint: enable=protected-access
  elif (isinstance(op, checkpointable.CheckpointableBase)
        and not isinstance(op, variables.Variable)):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      # Keep original name for classes masquerading as variables.
      if attr == checkpointable.VARIABLE_VALUE_KEY:
        full_name = name
      else:
        full_name = name + "_" + attr
      gathered = factory(full_name) if callable(factory) else factory
      # Recurse so that whatever the factory produced is coerced as well.
      for nested_saveable in saveable_objects_for_op(gathered, gathered.name):
        yield nested_saveable
    # pylint: enable=protected-access
  elif isinstance(op, resource_variable_ops.ResourceVariable):
    # A resource variable: save its graph element in graph mode, or the
    # variable object itself otherwise.
    # pylint: disable=protected-access
    variable = op._graph_element if op._in_graph_mode else op
    # pylint: enable=protected-access
    yield ResourceVariableSaveable(variable, "", name)
  else:
    # Some other variable or tensor; these are rejected when the outermost
    # context is executing eagerly.
    with ops.init_scope():
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

    variable = ops.internal_convert_to_tensor(op, as_ref=True)
    if not _tensor_comes_from_variable(variable):
      raise TypeError("names_to_saveables must be a dict mapping string "
                      "names to Tensors/Variables. Not a variable: %s" %
                      variable)
    if variable.op.type in reference_op_types:
      yield ReferenceVariableSaveable(variable, "", name)
    else:
      yield ResourceVariableSaveable(variable, "", name)
|
||||
|
||||
|
||||
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Group a collection of variables/saveables by the name they save under.

  Args:
    op_list: A list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name. Variables with save_slice_info are collected into a list under
    their shared full name.

  Raises:
    TypeError: If `op_list` is not a list, tuple, or set, or if an element
      converted to a tensor does not come from a variable.
    ValueError: If two variables resolve to the same name, if slices and
      non-slices share a name, or if something other than a
      `ResourceVariable` is passed while executing eagerly.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  result = {}
  # pylint: disable=protected-access
  for var in op_list:
    if isinstance(var, saveable_object.SaveableObject):
      result[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in result:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      result[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # Slices of one tensor accumulate in a list under the full name.
      name = var._save_slice_info.full_name
      if name not in result:
        result[name] = [var]
      else:
        if not isinstance(result[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        result[name].append(var)
    elif (isinstance(var, checkpointable.CheckpointableBase)
          and not isinstance(var, variables.Variable)):
      # Expand the object into its saveables and merge their names in.
      gathered = []
      for factory in var._gather_saveables_for_checkpoint().values():
        gathered.append(factory() if callable(factory) else factory)
      result.update(op_list_to_dict(gathered))
    elif context.executing_eagerly():
      if not isinstance(var, resource_variable_ops.ResourceVariable):
        raise ValueError(
            "Can only save/restore ResourceVariables when eager execution "
            "is enabled, type: %s." % type(var))
      # setdefault returns the existing entry if the name is already taken.
      registered = result.setdefault(var._shared_name, var)
      if registered is not var:
        raise ValueError(
            ("Two different ResourceVariable objects with the same "
             "shared_name '%s' were passed to the Saver. This likely means "
             "that they were created in different Graphs or isolation "
             "contexts, and may not be checkpointed together.") %
            (var._shared_name,))
    else:
      if convert_variable_to_tensor:
        if isinstance(var, resource_variable_ops.ResourceVariable):
          var = var._graph_element  # pylint: disable=protected-access
        else:
          var = ops.internal_convert_to_tensor(var, as_ref=True)
        if not _tensor_comes_from_variable(var):
          raise TypeError("Variable to save is not a Variable: %s" % var)
      # A read op is keyed by the op producing its first input instead of by
      # the read op itself.
      if var.op.type == "ReadVariableOp":
        name = var.op.inputs[0].op.name
      else:
        name = var.op.name
      if name in result:
        raise ValueError("At least two variables have the same name: %s" %
                         name)
      result[name] = var

  # pylint: enable=protected-access
  return result
|
||||
|
||||
|
||||
def _add_saveable(saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables`, rejecting repeated ops.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed. It is
      updated in place with the op of `saveable`.
    saveable: The saveable to record.

  Raises:
    ValueError: If the op of `saveable` is already in `seen_ops`.
  """
  saveable_op = saveable.op
  if saveable_op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(saveable_op)
|
||||
|
||||
|
||||
def validate_and_slice_inputs(names_to_saveables):
  """Converts a Saver's inputs into a flat list of `SaveableObject`s.

  Args:
    names_to_saveables: Either a dict mapping names to the objects to save
      under them, or any other value, which is first passed through
      `op_list_to_dict` to build such a dict.

  Returns:
    A list of SaveableObjects, produced by visiting the names in sorted
    order.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a checkpointable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  collected = []
  ops_already_added = set()
  # Walk the keys in sorted order; only names are compared, never the ops.
  for saveable_name in sorted(names_to_saveables):
    target = names_to_saveables[saveable_name]
    for converted in saveable_objects_for_op(target, saveable_name):
      _add_saveable(collected, ops_already_added, converted)
  return collected
|
@ -28,7 +28,7 @@ from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_ops
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -139,7 +139,7 @@ def _infer_var_name(var):
|
||||
Returns:
|
||||
Name of the `var`
|
||||
"""
|
||||
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
|
||||
name_to_var_dict = saveable_object_util.op_list_to_dict(var)
|
||||
if len(name_to_var_dict) > 1:
|
||||
raise TypeError("`var` = %s passed as arg violates the constraints. "
|
||||
"name_to_var_dict = %s" % (var, name_to_var_dict))
|
||||
|
Loading…
Reference in New Issue
Block a user