From 66ca3cd10df0bf9bb6586bf0a09ac5c5ed0a25fb Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 7 Dec 2018 12:40:49 -0800 Subject: [PATCH] Add a functional saver, use it for object-based checkpointing Pulls some utilities out of saver.py which are necessary to actually use it. The functional saver takes only SaveableObjects, so these are utilities for taking a list of whatever users pass in and converting them to those. One other code move for object-based checkpointing to avoid circular imports. Applications which need a SaverDef still use the old Saver. Serialization to SaverDef will be added to this saver in a followup. Does not actually wrap the new Saver's methods in @tf.function yet, since there are memory issues which need to be fixed first. PiperOrigin-RevId: 224561069 --- .../training/elastic_average_optimizer.py | 5 +- .../training/moving_average_optimizer.py | 3 +- tensorflow/python/BUILD | 11 +- .../python/training/checkpoint_utils.py | 6 +- .../python/training/checkpointable/BUILD | 6 +- .../python/training/checkpointable/base.py | 40 +-- .../python/training/checkpointable/util.py | 151 +++++--- tensorflow/python/training/saver.py | 331 ++--------------- tensorflow/python/training/saving/BUILD | 55 +++ .../training/saving/functional_saver.py | 101 ++++++ .../training/saving/functional_saver_test.py | 50 +++ .../training/{ => saving}/saveable_object.py | 0 .../training/saving/saveable_object_util.py | 340 ++++++++++++++++++ .../python/training/warm_starting_util.py | 4 +- 14 files changed, 687 insertions(+), 416 deletions(-) create mode 100644 tensorflow/python/training/saving/BUILD create mode 100644 tensorflow/python/training/saving/functional_saver.py create mode 100644 tensorflow/python/training/saving/functional_saver_test.py rename tensorflow/python/training/{ => saving}/saveable_object.py (100%) create mode 100644 tensorflow/python/training/saving/saveable_object_util.py diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e..fa1a7aaff0a 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9d..bf3e5c51f78 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cc36f1fc0e4..0a3ee65bc48 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3515,13 +3515,13 @@ py_library( exclude = [ "**/*test*", "training/checkpointable/**/*.py", + "training/saving/**/*.py", # The following targets have their own build rules (same name as the # file): "training/basic_session_run_hooks.py", "training/checkpoint_management.py", "training/distribute.py", "training/distribution_strategy_context.py", - "training/saveable_object.py", "training/saver.py", "training/session_run_hook.py", "training/training_util.py", @@ -3596,12 +3596,6 @@ py_library( ], ) -py_library( - name = "saveable_object", - srcs = ["training/saveable_object.py"], - srcs_version = "PY2AND3", -) - py_library( name = "checkpoint_management", srcs = ["training/checkpoint_management.py"], @@ -3655,7 +3649,6 @@ py_library( ":platform", ":pywrap_tensorflow", ":resource_variable_ops", - ":saveable_object", ":session", ":state_ops", ":string_ops", @@ -3665,6 +3658,8 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/training/saving:saveable_object_util", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 99b1f4c0d7a..74b46179e75 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -30,7 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.util.tf_export import tf_export @@ -311,10 +311,10 @@ def _set_checkpoint_initializer(variable, restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] - names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable]) + names_to_saveables = saveable_object_util.op_list_to_dict([variable]) saveable_objects = [] for name, op in names_to_saveables.items(): - for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name): + for s in saveable_object_util.saveable_objects_for_op(op, name): saveable_objects.append(s) assert len(saveable_objects) == 1 # Should be only one variable. diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 4ab5593d4f8..26a0ac35b76 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -25,9 +25,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops_gen", "//tensorflow/python:platform", - "//tensorflow/python:saveable_object", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/training/saving:saveable_object", ], ) @@ -114,7 +114,6 @@ py_library( "//tensorflow/python:init_ops", "//tensorflow/python:io_ops_gen", "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:saveable_object", "//tensorflow/python:saver", "//tensorflow/python:session", "//tensorflow/python:tensor_shape", @@ -123,6 +122,9 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", + "//tensorflow/python/training/saving:functional_saver", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/training/saving:saveable_object_util", ], ) diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index 095a90ddd4f..3cd1c6f9c8b 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -25,7 +25,6 @@ import weakref import six -from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,7 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saveable_object +from tensorflow.python.training.saving import saveable_object from tensorflow.python.util import nest from tensorflow.python.util import serialization from tensorflow.python.util import tf_decorator @@ -374,41 +373,10 @@ class _CheckpointPosition(object): eagerly. """ (restore_ops, - named_saveables, + tensor_saveables, python_saveables) = self._gather_ops_or_named_saveables() - - # Eagerly run restorations for Python state. - reader = pywrap_tensorflow.NewCheckpointReader( - self._checkpoint.save_path_string) - for saveable in python_saveables: - spec_names = [spec.name for spec in saveable.specs] - saveable.python_restore( - [reader.get_tensor(name) for name in spec_names]) - - # If we have new SaveableObjects, extract and cache restore ops. - if named_saveables: - validated_saveables = ( - self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access - validated_names = set(saveable.name for saveable in validated_saveables) - if set(named_saveables.keys()) != validated_names: - raise AssertionError( - ("Saveable keys changed when validating. Got back %s, was " - "expecting %s") % (named_saveables.keys(), validated_names)) - all_tensors = self._checkpoint.builder.bulk_restore( - filename_tensor=self._checkpoint.save_path_tensor, - saveables=validated_saveables, preferred_shard=-1, - restore_sequentially=False) - saveable_index = 0 - for saveable in validated_saveables: - num_specs = len(saveable.specs) - saveable_tensors = all_tensors[ - saveable_index:saveable_index + num_specs] - saveable_index += num_specs - restore_op = saveable.restore(saveable_tensors, restored_shapes=None) - if not context.executing_eagerly(): - assert saveable.name not in self._checkpoint.restore_ops_by_name - self._checkpoint.restore_ops_by_name[saveable.name] = restore_op - restore_ops.append(restore_op) + restore_ops.extend(self._checkpoint.restore_saveables( + tensor_saveables, python_saveables)) return restore_ops @property diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index d183fbdcf93..a54f41a54fa 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -40,11 +40,14 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import checkpoint_management from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.training import saveable_object as saveable_object_lib -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import saver as v1_saver_lib from tensorflow.python.training.checkpointable import base from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.saving import functional_saver +from tensorflow.python.training.saving import saveable_object as saveable_object_lib +from tensorflow.python.training.saving import saveable_object_util +from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -89,7 +92,6 @@ class _CheckpointRestoreCoordinator(object): referenced every restore (e.g. for Python state); otherwise they would create their own ops every restore. """ - self.builder = saver_lib.BulkSaverBuilder() self.object_graph_proto = object_graph_proto self.restore_uid = ops.uid() # Maps from objects to lists of attributes which were in the checkpoint but @@ -144,6 +146,57 @@ class _CheckpointRestoreCoordinator(object): if self.new_restore_ops_callback: self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable + def restore_saveables(self, tensor_saveables, python_saveables): + """Run or build restore operations for SaveableObjects. + + Args: + tensor_saveables: `SaveableObject`s which correspond to Tensors. + python_saveables: `PythonStateSaveable`s which correspond to Python + values. + + Returns: + When graph building, a list of restore operations, either cached or newly + created, to restore `tensor_saveables`. + """ + restore_ops = [] + # Eagerly run restorations for Python state. + reader = pywrap_tensorflow.NewCheckpointReader( + self.save_path_string) + for saveable in python_saveables: + spec_names = [spec.name for spec in saveable.specs] + saveable.python_restore( + [reader.get_tensor(name) for name in spec_names]) + + # If we have new SaveableObjects, extract and cache restore ops. + if tensor_saveables: + validated_saveables = saveable_object_util.validate_and_slice_inputs( + tensor_saveables) + validated_names = set(saveable.name for saveable in validated_saveables) + if set(tensor_saveables.keys()) != validated_names: + raise AssertionError( + ("Saveable keys changed when validating. Got back %s, was " + "expecting %s") % (tensor_saveables.keys(), validated_names)) + for saveable in validated_saveables: + if saveable.device: + device = saveable_object_util.set_cpu0(saveable.device) + else: + device = None + with ops.device(device): + tensors = [] + for spec in saveable.specs: + tensors.append( + io_ops.restore_v2( + self.save_path_tensor, + [spec.name], + [spec.slice_spec], + [spec.dtype])[0]) + restore_op = saveable.restore(tensors, restored_shapes=None) + if not context.executing_eagerly(): + assert saveable.name not in self.restore_ops_by_name + self.restore_ops_by_name[saveable.name] = restore_op + restore_ops.append(restore_op) + return restore_ops + class _NameBasedRestoreCoordinator(object): """Keeps the status of a name-based checkpoint restore.""" @@ -183,11 +236,11 @@ class _NameBasedRestoreCoordinator(object): continue else: saveable = saveable_factory - names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict( + names_to_saveables = saveable_object_util.op_list_to_dict( [saveable], convert_variable_to_tensor=False) for name, op in names_to_saveables.items(): - for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp( + for saveable_object in saveable_object_util.saveable_objects_for_op( op=op, name=name): yield saveable_object @@ -606,10 +659,10 @@ def _add_attributes_to_object_graph( # Figure out the name-based Saver's name for this variable. If it's # already a SaveableObject we'd just get the checkpoint key back, so # we leave full_name blank. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + saver_dict = saveable_object_util.op_list_to_dict( [maybe_saveable], convert_variable_to_tensor=False) full_name, = saver_dict.keys() - saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp( + saveables = tuple(saveable_object_util.saveable_objects_for_op( op=maybe_saveable, name=attribute.checkpoint_key)) for saveable in saveables: saveable.full_name = full_name @@ -1226,7 +1279,7 @@ class NameBasedSaverStatus(_LoadStatus): session = ops.get_default_session() with ops.device("/cpu:0"): saveables = self._gather_saveable_objects() - saver_lib.Saver(saveables).restore( + v1_saver_lib.Saver(saveables).restore( sess=session, save_path=self._checkpoint.save_path) def initialize_or_restore(self, session=None): @@ -1251,18 +1304,6 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface): fetches=fetches, feed_dict=feed_dict, **kwargs) -def _copy_saver_with_new_var_list(old_saver, new_var_list): - """Copy a `tf.train.Saver`'s state to a new Saver with different variables.""" - new_saver = saver_lib.Saver(var_list=new_var_list, max_to_keep=None) - # TODO(allenl): Move to copying functionality to Saver? - # pylint: disable=protected-access - new_saver._last_checkpoints = old_saver._last_checkpoints - new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted - new_saver._next_checkpoint_time = old_saver._next_checkpoint_time - # pylint: enable=protected-access - return new_saver - - class CheckpointableSaver(object): """Saves and restores a `Checkpointable` object and its dependencies. @@ -1301,7 +1342,8 @@ class CheckpointableSaver(object): # Op caching for save self._object_graph_feed_tensor = None self._last_save_object_graph = None - self._last_save_saver = None + self._file_prefix_feed_tensor = None + self._cached_save_operation = None # Op caching for restore, shared between _CheckpointRestoreCoordinators self._restore_op_cache = {} @@ -1368,13 +1410,16 @@ class CheckpointableSaver(object): base.NoRestoreSaveable( tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) - # TODO(allenl, haoliang): Swap in a function-based saver here. - return saver_lib.Saver( + # TODO(allenl): Swap in a function-based saver here once it can serialize + # to a SaverDef. + return v1_saver_lib.Saver( var_list=named_saveable_objects, max_to_keep=None) - def _prepare_save(self, - object_graph_tensor=None, - saveable_object_cache=None): + def _save_cached_when_graph_building( + self, + file_prefix, + object_graph_tensor=None, + saveable_object_cache=None): """Create or retrieve save ops. When graph building, `saveable_object_cache` will typically be non-`None`, @@ -1383,15 +1428,17 @@ class CheckpointableSaver(object): unnecessarily re-creating save ops. Args: + file_prefix: The prefix for saved checkpoint files. object_graph_tensor: A `Tensor` to which the current object graph will be fed. saveable_object_cache: A dictionary; if specified, used to cache `SaveableObject`s. Returns: - A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s - to feed when running save ops. The feed dict contains the current object - graph and any Python state to be saved in the checkpoint. + A two-element tuple with a filename tensor and a feed_dict of tensors to + feed when running it (if graph building). The feed dict contains the + current object graph and any Python state to be saved in the + checkpoint. When executing eagerly only the first argument is meaningful. """ (named_saveable_objects, graph_proto, feed_additions) = self._gather_saveables( @@ -1403,15 +1450,11 @@ class CheckpointableSaver(object): # constructors. That means the Saver needs to be copied with a new # var_list. or context.executing_eagerly()): - if self._last_save_object_graph is not None: - self._last_save_saver = _copy_saver_with_new_var_list( - old_saver=self._last_save_saver, - new_var_list=named_saveable_objects) - else: - self._last_save_saver = saver_lib.Saver( - var_list=named_saveable_objects, max_to_keep=None) + saver = functional_saver.Saver(named_saveable_objects) + with ops.device("/cpu:0"): + self._cached_save_operation = saver.save(file_prefix) self._last_save_object_graph = graph_proto - return self._last_save_saver, feed_additions + return self._cached_save_operation, feed_additions def save(self, file_prefix, checkpoint_number=None, session=None): """Save a training checkpoint. @@ -1435,36 +1478,42 @@ class CheckpointableSaver(object): Returns: The full path to the checkpoint. """ - feed_additions = {} + feed_dict = {} graph_building = not context.executing_eagerly() + if checkpoint_number: + file_prefix = "%s-%d" % (file_prefix, checkpoint_number) if graph_building: if self._object_graph_feed_tensor is None: with ops.device("/cpu:0"): self._object_graph_feed_tensor = constant_op.constant( "", dtype=dtypes.string) + self._file_prefix_feed_tensor = constant_op.constant( + "", dtype=dtypes.string) object_graph_tensor = self._object_graph_feed_tensor + file_prefix_tensor = self._file_prefix_feed_tensor + feed_dict[file_prefix_tensor] = file_prefix else: + with ops.device("/cpu:0"): + file_prefix_tensor = constant_op.constant( + file_prefix, dtype=dtypes.string) object_graph_tensor = None - saver, new_feed_additions = self._prepare_save( + file_io.recursive_create_dir(os.path.dirname(file_prefix)) + save_path, new_feed_additions = self._save_cached_when_graph_building( + file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor, saveable_object_cache=self._saveable_object_cache) if new_feed_additions: - feed_additions.update(new_feed_additions) + feed_dict.update(new_feed_additions) if not graph_building: session = None elif session is None: session = ops.get_default_session() - file_io.recursive_create_dir(os.path.dirname(file_prefix)) - with ops.device("/cpu:0"): - save_path = saver.save( - sess=_SessionWithFeedDictAdditions( - session=session, feed_additions=feed_additions), - save_path=file_prefix, - write_meta_graph=False, - write_state=False, - global_step=checkpoint_number) + if session: + save_path = session.run(save_path, feed_dict=feed_dict) + else: + save_path = save_path.numpy() return save_path def restore(self, save_path): @@ -1753,9 +1802,9 @@ class Checkpoint(tracking.Checkpointable): Returns: The full path to the checkpoint (i.e. `file_prefix`). """ - return self._saver.save( + return compat.as_str(self._saver.save( file_prefix=file_prefix, - session=session) + session=session)) @property def save_counter(self): diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 4cd09f8a1d5..04a72164849 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -14,7 +14,11 @@ # ============================================================================== # pylint: disable=invalid-name -"""Save and restore variables.""" +"""Save and restore variables. + +Symbols in this file are deprecated. See replacements in +tensorflow/python/training/checkpointable and tensorflow/python/training/saving. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -25,7 +29,6 @@ import time import uuid import numpy as np -import six from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 @@ -42,16 +45,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import saveable_object from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -67,31 +69,6 @@ get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes remove_checkpoint = checkpoint_management.remove_checkpoint -# Op names which identify variable reads which should be saved. -_VARIABLE_OPS = set(["Variable", - "VariableV2", - "AutoReloadVariable", - "VarHandleOp", - "ReadVariableOp"]) - - -def _set_cpu0(device_string): - """Creates a new device string based on `device_string` but using /CPU:0. - - If the device is already on /CPU:0, this is a no-op. - - Args: - device_string: A device string. - - Returns: - A device string. - """ - parsed_device = pydev.DeviceSpec.from_string(device_string) - parsed_device.device_type = "CPU" - parsed_device.device_index = 0 - return parsed_device.to_string() - - class BaseSaverBuilder(object): """Base class for Savers. @@ -101,64 +78,9 @@ class BaseSaverBuilder(object): SaveSpec = saveable_object.SaveSpec SaveableObject = saveable_object.SaveableObject - class VariableSaveable(SaveableObject): - """SaveableObject implementation that handles Variables.""" - - def __init__(self, var, slice_spec, name): - spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype) - super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - restored_tensor = restored_tensors[0] - if restored_shapes is not None: - restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) - return state_ops.assign( - self.op, - restored_tensor, - validate_shape=restored_shapes is None and - self.op.get_shape().is_fully_defined()) - - class ResourceVariableSaveable(SaveableObject): - """SaveableObject implementation that handles ResourceVariables.""" - - def __init__(self, var, slice_spec, name): - self._var_device = var.device - self._var_shape = var.shape - if isinstance(var, ops.Tensor): - self.handle_op = var.op.inputs[0] - tensor = var - elif isinstance(var, resource_variable_ops.ResourceVariable): - - def _read_variable_closure(v): - def f(): - with ops.device(v.device): - x = v.read_value() - # To allow variables placed on non-CPU devices to be checkpointed, - # we copy them to CPU on the same machine first. - with ops.device("/device:CPU:0"): - return array_ops.identity(x) - return f - - self.handle_op = var.handle - tensor = _read_variable_closure(var) - else: - raise ValueError( - "Saveable is neither a resource variable nor a read operation." - " Got: %s" % repr(var)) - spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name, - dtype=var.dtype) - super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__( - var, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - restored_tensor = restored_tensors[0] - if restored_shapes is not None: - restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) - # Copy the restored tensor to the variable's device. - with ops.device(self._var_device): - restored_tensor = array_ops.identity(restored_tensor) - return resource_variable_ops.shape_safe_assign_variable_handle( - self.handle_op, self._var_shape, restored_tensor) + # Aliases for code which was moved but still has lots of users. + VariableSaveable = saveable_object_util.ReferenceVariableSaveable + ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable def __init__(self, write_version=saver_pb2.SaverDef.V2): self._write_version = write_version @@ -224,7 +146,11 @@ class BaseSaverBuilder(object): del restore_sequentially all_tensors = [] for saveable in saveables: - with ops.device(_set_cpu0(saveable.device) if saveable.device else None): + if saveable.device: + device = saveable_object_util.set_cpu0(saveable.device) + else: + device = None + with ops.device(device): all_tensors.extend( self.restore_op(filename_tensor, saveable, preferred_shard)) return all_tensors @@ -336,7 +262,7 @@ class BaseSaverBuilder(object): last_device = None for shard, (device, saveables) in enumerate(per_device): last_device = device - with ops.device(_set_cpu0(device)): + with ops.device(saveable_object_util.set_cpu0(device)): sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor) sharded_prefixes.append(sharded_filename) @@ -344,7 +270,7 @@ class BaseSaverBuilder(object): with ops.control_dependencies([x.op for x in sharded_saves]): # Co-locates the merge step with the last device. - with ops.device(_set_cpu0(last_device)): + with ops.device(saveable_object_util.set_cpu0(last_device)): # V2 format write path consists of a metadata merge step. Once merged, # attempts to delete the temporary directory, "_temp". merge_step = gen_io_ops.merge_v2_checkpoints( @@ -459,10 +385,6 @@ class BaseSaverBuilder(object): name="restore_shard")) return control_flow_ops.group(*sharded_restores, name="restore_all") - @staticmethod - def _IsVariable(v): - return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS - def _GroupByDevices(self, saveables): """Group Variable tensor slices per device. @@ -490,220 +412,6 @@ class BaseSaverBuilder(object): per_device[canonical_device.pop()].append(saveable) return sorted(per_device.items(), key=lambda t: t[0]) - @staticmethod - def OpListToDict(op_list, convert_variable_to_tensor=True): - """Create a dictionary of names to operation lists. - - Args: - op_list: A list, tuple, or set of Variables or SaveableObjects. - convert_variable_to_tensor: Whether or not to convert single Variables - with no slice info into Tensors. - - Returns: - A dictionary of names to the operations that must be saved under - that name. Variables with save_slice_info are grouped together under the - same key in no particular order. - - Raises: - TypeError: If the type of op_list or its elements is not supported. - ValueError: If at least two saveables share the same name. - """ - if not isinstance(op_list, (list, tuple, set)): - raise TypeError("Variables to save should be passed in a dict or a " - "list: %s" % op_list) - # When ResourceVariables are converted to Tensors, read ops are added to the - # graph. Sorting the op_list ensures that the resulting graph is always - # constructed in a deterministic way: - op_list = sorted(op_list, key=lambda x: x.name) - names_to_saveables = {} - # pylint: disable=protected-access - for var in op_list: - if isinstance(var, BaseSaverBuilder.SaveableObject): - names_to_saveables[var.name] = var - elif isinstance(var, variables.PartitionedVariable): - if var.name in names_to_saveables: - raise ValueError("At least two variables have the same name: %s" % - var.name) - names_to_saveables[var.name] = var - elif isinstance(var, variables.Variable) and var._save_slice_info: - name = var._save_slice_info.full_name - if name in names_to_saveables: - if not isinstance(names_to_saveables[name], list): - raise ValueError("Mixing slices and non-slices with the same name: " - "%s" % name) - names_to_saveables[name].append(var) - else: - names_to_saveables[name] = [var] - elif (isinstance(var, checkpointable.CheckpointableBase) - and not isinstance(var, variables.Variable)): - checkpointable_saveables = [ - (factory() if callable(factory) else factory) - for factory in var._gather_saveables_for_checkpoint().values()] - names_to_saveables.update( - BaseSaverBuilder.OpListToDict(checkpointable_saveables)) - else: - if context.executing_eagerly(): - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError( - "Can only save/restore ResourceVariables when eager execution " - "is enabled, type: %s." % type(var)) - set_var = names_to_saveables.setdefault(var._shared_name, var) - if set_var is not var: - raise ValueError( - ("Two different ResourceVariable objects with the same " - "shared_name '%s' were passed to the Saver. This likely means " - "that they were created in different Graphs or isolation " - "contexts, and may not be checkpointed together.") % - (var._shared_name,)) - else: - if convert_variable_to_tensor: - if isinstance(var, resource_variable_ops.ResourceVariable): - var = var._graph_element # pylint: disable=protected-access - else: - var = ops.internal_convert_to_tensor(var, as_ref=True) - if not BaseSaverBuilder._IsVariable(var): - raise TypeError("Variable to save is not a Variable: %s" % var) - if var.op.type == "ReadVariableOp": - name = var.op.inputs[0].op.name - else: - name = var.op.name - if name in names_to_saveables: - raise ValueError("At least two variables have the same name: %s" % - name) - names_to_saveables[name] = var - - # pylint: enable=protected-access - return names_to_saveables - - @staticmethod - def SaveableObjectsForOp(op, name): - """Create `SaveableObject`s from an operation. - - Args: - op: A variable, operation, or SaveableObject to coerce into a - SaveableObject. - name: A string name for the SaveableObject. - - Yields: - `SaveableObject`s which together save/restore `op`. - - Raises: - TypeError: If `name` is not a string. - ValueError: For operations with no known conversion to SaveableObject. - """ - if not isinstance(name, six.string_types): - raise TypeError( - "names_to_saveables must be a dict mapping string names to " - "checkpointable operations. Name is not a string: %s" % name) - if isinstance(op, BaseSaverBuilder.SaveableObject): - yield op - elif isinstance(op, (list, tuple, variables.PartitionedVariable)): - if isinstance(op, variables.PartitionedVariable): - op = list(op) - # A set of slices. - slice_name = None - # pylint: disable=protected-access - for variable in op: - if not isinstance(variable, variables.Variable): - raise ValueError("Slices must all be Variables: %s" % variable) - if not variable._save_slice_info: - raise ValueError("Slices must all be slices: %s" % variable) - if slice_name is None: - slice_name = variable._save_slice_info.full_name - elif slice_name != variable._save_slice_info.full_name: - raise ValueError( - "Slices must all be from the same tensor: %s != %s" % - (slice_name, variable._save_slice_info.full_name)) - if variable.op.type in ["Variable", "VariableV2", - "AutoReloadVariable"]: - yield BaseSaverBuilder.VariableSaveable( - variable, variable._save_slice_info.spec, name) - else: - yield BaseSaverBuilder.ResourceVariableSaveable( - variable, variable._save_slice_info.spec, name) - # pylint: enable=protected-access - elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance( - op, variables.Variable): - # pylint: disable=protected-access - for attr, factory in op._gather_saveables_for_checkpoint().items(): - if attr == checkpointable.VARIABLE_VALUE_KEY: - # Keep original name for classes masquerading as variables. - full_name = name - else: - full_name = name + "_" + attr - op = (factory(full_name) if callable(factory) else factory) - for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name): - yield op - # pylint: enable=protected-access - else: - # A variable or tensor. - if context.executing_eagerly(): - if not isinstance(op, resource_variable_ops.ResourceVariable): - raise ValueError("Can only save/restore ResourceVariable eager " - "mode is enabled, type: %s." % type(op)) - yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name) - else: - if isinstance(op, resource_variable_ops.ResourceVariable): - variable = op._graph_element # pylint: disable=protected-access - else: - variable = ops.internal_convert_to_tensor(op, as_ref=True) - if not BaseSaverBuilder._IsVariable(variable): - raise TypeError("names_to_saveables must be a dict mapping string " - "names to Tensors/Variables. Not a variable: %s" % - variable) - if variable.op.type in ["Variable", "VariableV2", - "AutoReloadVariable"]: - yield BaseSaverBuilder.VariableSaveable(variable, "", name) - else: - yield BaseSaverBuilder.ResourceVariableSaveable( - variable, "", name) - - def _ValidateAndSliceInputs(self, names_to_saveables): - """Returns the variables and names that will be used for a Saver. - - Args: - names_to_saveables: A dict (k, v) where k is the name of an operation and - v is an operation to save or a BaseSaverBuilder.Saver. - - Returns: - A list of BaseSaverBuilder.SaveableObject objects. - - Raises: - TypeError: If any of the keys are not strings or any of the - values are not one of Tensor or Variable or a checkpointable operation. - ValueError: If the same operation is given in more than one value - (this also applies to slices of SlicedVariables). - """ - if not isinstance(names_to_saveables, dict): - names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables) - - saveables = [] - seen_ops = set() - for name, op in sorted(names_to_saveables.items(), - # Avoid comparing ops, sort only by name. - key=lambda x: x[0]): - for converted_saveable_object in self.SaveableObjectsForOp(op, name): - self._AddSaveable(saveables, seen_ops, converted_saveable_object) - return saveables - - def _AddSaveable(self, saveables, seen_ops, saveable): - """Adds the saveable to the saveables list. - - Args: - saveables: List to append the SaveableObject to. - seen_ops: Set of the ops of the saveables already processed. Used to - check that each saveable is only saved once. - saveable: The saveable. - - Raises: - ValueError: If the saveable has already been processed. - """ - if saveable.op in seen_ops: - raise ValueError("The same saveable will be restored with two names: %s" % - saveable.name) - saveables.append(saveable) - seen_ops.add(saveable.op) - def build(self, names_to_saveables, reshape=False, @@ -775,7 +483,8 @@ class BaseSaverBuilder(object): raise ValueError("save and restore operations need to be built together " " when eager execution is not enabled.") - saveables = self._ValidateAndSliceInputs(names_to_saveables) + saveables = saveable_object_util.validate_and_slice_inputs( + names_to_saveables) if max_to_keep is None: max_to_keep = 0 @@ -1910,7 +1619,7 @@ def saver_from_object_based_checkpoint( if builder is None: builder = BulkSaverBuilder() - saveables = builder._ValidateAndSliceInputs(var_list) # pylint: disable=protected-access + saveables = saveable_object_util.validate_and_slice_inputs(var_list) current_names = set() for saveable in saveables: for spec in saveable.specs: diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD new file mode 100644 index 00000000000..67ccd59b88c --- /dev/null +++ b/tensorflow/python/training/saving/BUILD @@ -0,0 +1,55 @@ +# Description: +# Low-level utilities for reading and writing checkpoints. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "functional_saver", + srcs = ["functional_saver.py"], + srcs_version = "PY2AND3", + deps = [ + ":saveable_object", + ":saveable_object_util", + "//tensorflow/python/eager:def_function", + ], +) + +cuda_py_test( + name = "functional_saver_test", + size = "medium", + srcs = [ + "functional_saver_test.py", + ], + additional_deps = [ + ":functional_saver", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "saveable_object", + srcs = ["saveable_object.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "saveable_object_util", + srcs = ["saveable_object_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//tensorflow/python/training/checkpointable:base", + "@six_archive//:six", + ], +) diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py new file mode 100644 index 00000000000..7eed3336626 --- /dev/null +++ b/tensorflow/python/training/saving/functional_saver.py @@ -0,0 +1,101 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Saves and restore variables inside traced @tf.functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.training.saving import saveable_object_util + + +class Saver(object): + """A minimal utility class for saving and restoring checkpoints. + + Note that this is a low-level utility which stores Tensors in the keys + specified by `SaveableObject`s. Higher-level utilities for object-based + checkpointing are built on top of it. + """ + + def __init__(self, saveable_objects): + """Specify a list of `SaveableObject`s to save and restore. + + Args: + saveable_objects: A list of `SaveableObject`s. + """ + saveable_objects = list(saveable_objects) + for saveable in saveable_objects: + if not isinstance(saveable, saveable_object.SaveableObject): + raise ValueError( + "Saver expected a list of SaveableObjects, got %s." % (saveable,)) + self._saveable_objects = saveable_objects + + # TODO(b/120569892): Use tf.function here + def save(self, file_prefix): + """Save the saveable objects to a checkpoint with `file_prefix`. + + Args: + file_prefix: A string or scalar string Tensor containing the prefix to + save under. + Returns: + A scalar string Tensor containing `file_prefix` with control dependencies + on the save ops. + """ + tensor_names = [] + tensors = [] + tensor_slices = [] + for saveable in self._saveable_objects: + for spec in saveable.specs: + tensor_names.append(spec.name) + tensors.append(spec.tensor) + tensor_slices.append(spec.slice_spec) + with ops.control_dependencies( + [io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)]): + return array_ops.identity(file_prefix) + + # TODO(b/120569892): Use tf.function here + def restore(self, file_prefix): + """Restore the saveable objects from a checkpoint with `file_prefix`. + + Args: + file_prefix: A string or scalar string Tensor containing the prefix for + files to read from. + + Returns: + An operation which restores the `Saver`'s `SaveableObject`s when run, or + None if executing eagerly. + """ + restore_ops = [] + for saveable in self._saveable_objects: + if saveable.device: + device = saveable_object_util.set_cpu0(saveable.device) + else: + device = None + with ops.device(device): + tensors = [] + for spec in saveable.specs: + tensors.append( + io_ops.restore_v2( + file_prefix, + [spec.name], + [spec.slice_spec], + [spec.dtype])[0]) + restore_ops.append(saveable.restore(tensors, restored_shapes=None)) + return control_flow_ops.group(restore_ops) diff --git a/tensorflow/python/training/saving/functional_saver_test.py b/tensorflow/python/training/saving/functional_saver_test.py new file mode 100644 index 00000000000..40002255aac --- /dev/null +++ b/tensorflow/python/training/saving/functional_saver_test.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for the functional saver.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.saving import functional_saver +from tensorflow.python.training.saving import saveable_object_util + + +class SaverTest(test.TestCase): + + def test_resource_variable(self): + v1 = resource_variable_ops.ResourceVariable(2.) + saver = functional_saver.Saver( + saveable_object_util.saveable_objects_for_op(v1, "x")) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + save_path = saver.save(constant_op.constant(prefix)) + v1.assign(1.) + saver.restore(save_path) + self.assertEqual(2., self.evaluate(v1)) + + v2 = resource_variable_ops.ResourceVariable(3.) + second_saver = functional_saver.Saver( + saveable_object_util.saveable_objects_for_op(v2, "x")) + second_saver.restore(save_path) + self.assertEqual(2., self.evaluate(v2)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/saveable_object.py b/tensorflow/python/training/saving/saveable_object.py similarity index 100% rename from tensorflow/python/training/saveable_object.py rename to tensorflow/python/training/saving/saveable_object.py diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py new file mode 100644 index 00000000000..fa88d2c6ebd --- /dev/null +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -0,0 +1,340 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for working with and creating SaveableObjects.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.eager import context +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.saving import saveable_object + + +# Op names which identify variable reads which should be saved. +_VARIABLE_OPS = set(["Variable", + "VariableV2", + "AutoReloadVariable", + "VarHandleOp", + "ReadVariableOp"]) + + +def set_cpu0(device_string): + """Creates a new device string based on `device_string` but using /CPU:0. + + If the device is already on /CPU:0, this is a no-op. + + Args: + device_string: A device string. + + Returns: + A device string. + """ + parsed_device = pydev.DeviceSpec.from_string(device_string) + parsed_device.device_type = "CPU" + parsed_device.device_index = 0 + return parsed_device.to_string() + + +class ReferenceVariableSaveable(saveable_object.SaveableObject): + """SaveableObject implementation that handles reference variables.""" + + def __init__(self, var, slice_spec, name): + spec = saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype) + super(ReferenceVariableSaveable, self).__init__(var, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + restored_tensor = restored_tensors[0] + if restored_shapes is not None: + restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) + return state_ops.assign( + self.op, + restored_tensor, + validate_shape=restored_shapes is None and + self.op.get_shape().is_fully_defined()) + + +class ResourceVariableSaveable(saveable_object.SaveableObject): + """SaveableObject implementation that handles ResourceVariables.""" + + def __init__(self, var, slice_spec, name): + self._var_device = var.device + self._var_shape = var.shape + if isinstance(var, ops.Tensor): + self.handle_op = var.op.inputs[0] + tensor = var + elif isinstance(var, resource_variable_ops.ResourceVariable): + + def _read_variable_closure(v): + def f(): + with ops.device(v.device): + x = v.read_value() + # To allow variables placed on non-CPU devices to be checkpointed, + # we copy them to CPU on the same machine first. + with ops.device("/device:CPU:0"): + return array_ops.identity(x) + return f + + self.handle_op = var.handle + tensor = _read_variable_closure(var) + else: + raise ValueError( + "Saveable is neither a resource variable nor a read operation." + " Got: %s" % repr(var)) + spec = saveable_object.SaveSpec(tensor, slice_spec, name, + dtype=var.dtype) + super(ResourceVariableSaveable, self).__init__(var, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + restored_tensor = restored_tensors[0] + if restored_shapes is not None: + restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) + # Copy the restored tensor to the variable's device. + with ops.device(self._var_device): + restored_tensor = array_ops.identity(restored_tensor) + return resource_variable_ops.shape_safe_assign_variable_handle( + self.handle_op, self._var_shape, restored_tensor) + + +def _tensor_comes_from_variable(v): + return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS + + +def saveable_objects_for_op(op, name): + """Create `SaveableObject`s from an operation. + + Args: + op: A variable, operation, or SaveableObject to coerce into a + SaveableObject. + name: A string name for the SaveableObject. + + Yields: + `SaveableObject`s which together save/restore `op`. + + Raises: + TypeError: If `name` is not a string. + ValueError: For operations with no known conversion to SaveableObject. + """ + if not isinstance(name, six.string_types): + raise TypeError( + "names_to_saveables must be a dict mapping string names to " + "checkpointable operations. Name is not a string: %s" % name) + if isinstance(op, saveable_object.SaveableObject): + yield op + elif isinstance(op, (list, tuple, variables.PartitionedVariable)): + if isinstance(op, variables.PartitionedVariable): + op = list(op) + # A set of slices. + slice_name = None + # pylint: disable=protected-access + for variable in op: + if not isinstance(variable, variables.Variable): + raise ValueError("Slices must all be Variables: %s" % variable) + if not variable._save_slice_info: + raise ValueError("Slices must all be slices: %s" % variable) + if slice_name is None: + slice_name = variable._save_slice_info.full_name + elif slice_name != variable._save_slice_info.full_name: + raise ValueError( + "Slices must all be from the same tensor: %s != %s" % + (slice_name, variable._save_slice_info.full_name)) + if variable.op.type in ["Variable", "VariableV2", + "AutoReloadVariable"]: + yield ReferenceVariableSaveable( + variable, variable._save_slice_info.spec, name) + else: + yield ResourceVariableSaveable( + variable, variable._save_slice_info.spec, name) + # pylint: enable=protected-access + elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance( + op, variables.Variable): + # pylint: disable=protected-access + for attr, factory in op._gather_saveables_for_checkpoint().items(): + if attr == checkpointable.VARIABLE_VALUE_KEY: + # Keep original name for classes masquerading as variables. + full_name = name + else: + full_name = name + "_" + attr + op = (factory(full_name) if callable(factory) else factory) + for op in saveable_objects_for_op(op, op.name): + yield op + # pylint: enable=protected-access + else: + # A variable or tensor. + if isinstance(op, resource_variable_ops.ResourceVariable): + # pylint: disable=protected-access + if op._in_graph_mode: + variable = op._graph_element + else: + variable = op + # pylint: enable=protected-access + yield ResourceVariableSaveable(variable, "", name) + else: + with ops.init_scope(): + if context.executing_eagerly(): + raise ValueError("Can only save/restore ResourceVariables when " + "executing eagerly, got type: %s." % type(op)) + + variable = ops.internal_convert_to_tensor(op, as_ref=True) + if not _tensor_comes_from_variable(variable): + raise TypeError("names_to_saveables must be a dict mapping string " + "names to Tensors/Variables. Not a variable: %s" % + variable) + if variable.op.type in ["Variable", "VariableV2", + "AutoReloadVariable"]: + yield ReferenceVariableSaveable(variable, "", name) + else: + yield ResourceVariableSaveable( + variable, "", name) + + +def op_list_to_dict(op_list, convert_variable_to_tensor=True): + """Create a dictionary of names to operation lists. + + Args: + op_list: A list, tuple, or set of Variables or SaveableObjects. + convert_variable_to_tensor: Whether or not to convert single Variables + with no slice info into Tensors. + + Returns: + A dictionary of names to the operations that must be saved under + that name. Variables with save_slice_info are grouped together under the + same key in no particular order. + + Raises: + TypeError: If the type of op_list or its elements is not supported. + ValueError: If at least two saveables share the same name. + """ + if not isinstance(op_list, (list, tuple, set)): + raise TypeError("Variables to save should be passed in a dict or a " + "list: %s" % op_list) + # When ResourceVariables are converted to Tensors, read ops are added to the + # graph. Sorting the op_list ensures that the resulting graph is always + # constructed in a deterministic way: + op_list = sorted(op_list, key=lambda x: x.name) + names_to_saveables = {} + # pylint: disable=protected-access + for var in op_list: + if isinstance(var, saveable_object.SaveableObject): + names_to_saveables[var.name] = var + elif isinstance(var, variables.PartitionedVariable): + if var.name in names_to_saveables: + raise ValueError("At least two variables have the same name: %s" % + var.name) + names_to_saveables[var.name] = var + elif isinstance(var, variables.Variable) and var._save_slice_info: + name = var._save_slice_info.full_name + if name in names_to_saveables: + if not isinstance(names_to_saveables[name], list): + raise ValueError("Mixing slices and non-slices with the same name: " + "%s" % name) + names_to_saveables[name].append(var) + else: + names_to_saveables[name] = [var] + elif (isinstance(var, checkpointable.CheckpointableBase) + and not isinstance(var, variables.Variable)): + checkpointable_saveables = [ + (factory() if callable(factory) else factory) + for factory in var._gather_saveables_for_checkpoint().values()] + names_to_saveables.update( + op_list_to_dict(checkpointable_saveables)) + else: + if context.executing_eagerly(): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "Can only save/restore ResourceVariables when eager execution " + "is enabled, type: %s." % type(var)) + set_var = names_to_saveables.setdefault(var._shared_name, var) + if set_var is not var: + raise ValueError( + ("Two different ResourceVariable objects with the same " + "shared_name '%s' were passed to the Saver. This likely means " + "that they were created in different Graphs or isolation " + "contexts, and may not be checkpointed together.") % + (var._shared_name,)) + else: + if convert_variable_to_tensor: + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var._graph_element # pylint: disable=protected-access + else: + var = ops.internal_convert_to_tensor(var, as_ref=True) + if not _tensor_comes_from_variable(var): + raise TypeError("Variable to save is not a Variable: %s" % var) + if var.op.type == "ReadVariableOp": + name = var.op.inputs[0].op.name + else: + name = var.op.name + if name in names_to_saveables: + raise ValueError("At least two variables have the same name: %s" % + name) + names_to_saveables[name] = var + + # pylint: enable=protected-access + return names_to_saveables + + +def _add_saveable(saveables, seen_ops, saveable): + """Adds the saveable to the saveables list. + + Args: + saveables: List to append the SaveableObject to. + seen_ops: Set of the ops of the saveables already processed. Used to + check that each saveable is only saved once. + saveable: The saveable. + + Raises: + ValueError: If the saveable has already been processed. + """ + if saveable.op in seen_ops: + raise ValueError("The same saveable will be restored with two names: %s" % + saveable.name) + saveables.append(saveable) + seen_ops.add(saveable.op) + + +def validate_and_slice_inputs(names_to_saveables): + """Returns the variables and names that will be used for a Saver. + + Args: + names_to_saveables: A dict (k, v) where k is the name of an operation and + v is an operation to save or a BaseSaverBuilder.Saver. + + Returns: + A list of SaveableObjects. + + Raises: + TypeError: If any of the keys are not strings or any of the + values are not one of Tensor or Variable or a checkpointable operation. + ValueError: If the same operation is given in more than one value + (this also applies to slices of SlicedVariables). + """ + if not isinstance(names_to_saveables, dict): + names_to_saveables = op_list_to_dict(names_to_saveables) + + saveables = [] + seen_ops = set() + for name, op in sorted(names_to_saveables.items(), + # Avoid comparing ops, sort only by name. + key=lambda x: x[0]): + for converted_saveable_object in saveable_objects_for_op(op, name): + _add_saveable(saveables, seen_ops, converted_saveable_object) + return saveables diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index 8c97f101da8..1382b8ce72e 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -28,7 +28,7 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_ops from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.util.tf_export import tf_export @@ -139,7 +139,7 @@ def _infer_var_name(var): Returns: Name of the `var` """ - name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var) + name_to_var_dict = saveable_object_util.op_list_to_dict(var) if len(name_to_var_dict) > 1: raise TypeError("`var` = %s passed as arg violates the constraints. " "name_to_var_dict = %s" % (var, name_to_var_dict))