diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 532e78b6343..01f624ac66f 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -53,7 +53,6 @@ from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from tensorflow.python.util import compat
-
 
 # Op names which identify variable reads which should be saved.
 _VARIABLE_OPS = set(["Variable",
                      "VariableV2",
@@ -63,7 +62,7 @@ _VARIABLE_OPS = set(["Variable",
 
 
 def _set_cpu0(device_string):
   """Creates a new device string based on `device_string` but using /CPU:0.
 
   If the device is already on /CPU:0, this is a no-op.
 
@@ -73,38 +72,38 @@ def _set_cpu0(device_string):
   Returns:
     A device string.
   """
   parsed_device = pydev.DeviceSpec.from_string(device_string)
   parsed_device.device_type = "CPU"
   parsed_device.device_index = 0
   return parsed_device.to_string()
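A quick illustration of the device remap above — a minimal sketch using the public `tf.DeviceSpec` counterpart of the internal `pydev.DeviceSpec` (TF 1.x, where the spec object is mutable):

    import tensorflow as tf

    def set_cpu0(device_string):
      # Re-parse the device string, then pin the device type/index to CPU:0;
      # the other fields (job, replica, task) are preserved.
      spec = tf.DeviceSpec.from_string(device_string)
      spec.device_type = "CPU"
      spec.device_index = 0
      return spec.to_string()

    print(set_cpu0("/job:worker/task:3/device:GPU:1"))
    # -> /job:worker/task:3/device:CPU:0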
""" - # pylint: disable=unused-argument - raise ValueError("Calling an abstract method.") + # pylint: disable=unused-argument + raise ValueError("Calling an abstract method.") - class VariableSaveable(SaveableObject): - """SaveableObject implementation that handles Variables.""" + class VariableSaveable(SaveableObject): + """SaveableObject implementation that handles Variables.""" - def __init__(self, var, slice_spec, name): - spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name) - super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name) + def __init__(self, var, slice_spec, name): + spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name) + super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name) - def restore(self, restored_tensors, restored_shapes): - restored_tensor = restored_tensors[0] - if restored_shapes is not None: - restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) - return state_ops.assign( - self.op, - restored_tensor, - validate_shape=restored_shapes is None and - self.op.get_shape().is_fully_defined()) + def restore(self, restored_tensors, restored_shapes): + restored_tensor = restored_tensors[0] + if restored_shapes is not None: + restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) + return state_ops.assign( + self.op, + restored_tensor, + validate_shape=restored_shapes is None and + self.op.get_shape().is_fully_defined()) - class ResourceVariableSaveable(SaveableObject): - """SaveableObject implementation that handles ResourceVariables.""" + class ResourceVariableSaveable(SaveableObject): + """SaveableObject implementation that handles ResourceVariables.""" - def __init__(self, var, slice_spec, name): - if isinstance(var, ops.Tensor): - self.handle_op = var.op.inputs[0] - elif isinstance(var, resource_variable_ops.ResourceVariable): - self.handle_op = var.handle - else: - raise ValueError( - "Saveable is neither a resource variable nor a read operation." - " Got: %s" % repr(var)) - spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name) - super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__( - var, [spec], name) + def __init__(self, var, slice_spec, name): + if isinstance(var, ops.Tensor): + self.handle_op = var.op.inputs[0] + elif isinstance(var, resource_variable_ops.ResourceVariable): + self.handle_op = var.handle + else: + raise ValueError( + "Saveable is neither a resource variable nor a read operation." + " Got: %s" % repr(var)) + spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name) + super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__( + var, [spec], name) - def restore(self, restored_tensors, restored_shapes): - restored_tensor = restored_tensors[0] - if restored_shapes is not None: - restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) - return resource_variable_ops.assign_variable_op( - self.handle_op, restored_tensor) + def restore(self, restored_tensors, restored_shapes): + restored_tensor = restored_tensors[0] + if restored_shapes is not None: + restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) + return resource_variable_ops.assign_variable_op( + self.handle_op, restored_tensor) - def __init__(self, write_version=saver_pb2.SaverDef.V2): - self._write_version = write_version + def __init__(self, write_version=saver_pb2.SaverDef.V2): + self._write_version = write_version - def save_op(self, filename_tensor, saveables): - """Create an Op to save 'saveables'. 
 
   def save_op(self, filename_tensor, saveables):
     """Create an Op to save 'saveables'.
 
     This is intended to be overridden by subclasses that want to generate
     different Ops.
 
@@ -197,32 +196,32 @@ class BaseSaverBuilder(object):
       RuntimeError: (implementation detail) if "self._write_version" is an
         unexpected value.
     """
     # pylint: disable=protected-access
     tensor_names = []
     tensors = []
     tensor_slices = []
     for saveable in saveables:
       for spec in saveable.specs:
         tensor_names.append(spec.name)
         tensors.append(spec.tensor)
         tensor_slices.append(spec.slice_spec)
     if self._write_version == saver_pb2.SaverDef.V1:
       return io_ops._save(
           filename=filename_tensor,
           tensor_names=tensor_names,
           tensors=tensors,
           tensor_slices=tensor_slices)
     elif self._write_version == saver_pb2.SaverDef.V2:
       # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
       # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker-<shard>".
       return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                             tensors)
     else:
       raise RuntimeError("Unexpected write_version: " + self._write_version)
 
   # pylint: disable=unused-argument
   def restore_op(self, filename_tensor, saveable, preferred_shard):
     """Create ops to restore 'saveable'.
 
     This is intended to be overridden by subclasses that want to generate
     different Ops.
 
@@ -236,21 +235,22 @@ class BaseSaverBuilder(object):
       A list of Tensors resulting from reading 'saveable' from
         'filename'.
     """
     # pylint: disable=protected-access
     tensors = []
     for spec in saveable.specs:
       tensors.append(
           io_ops.restore_v2(
               filename_tensor,
               [spec.name],
               [spec.slice_spec],
               [spec.tensor.dtype])[0])
 
     return tensors
+
   # pylint: enable=unused-argument
 
   def sharded_filename(self, filename_tensor, shard, num_shards):
     """Append sharding information to a filename.
 
     Args:
       filename_tensor: A string tensor.
       shard: Integer. The shard for the filename.
       num_shards: An int Tensor for the number of shards.
 
     Returns:
       A string tensor.
@@ -260,11 +260,11 @@ class BaseSaverBuilder(object):
     Returns:
       A string tensor.
     """
     # pylint: disable=protected-access
     return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards)
 
   def _AddSaveOps(self, filename_tensor, saveables):
     """Add ops to save variables that are on the same shard.
 
     Args:
       filename_tensor: String Tensor.
@@ -273,11 +273,11 @@ class BaseSaverBuilder(object):
     Returns:
       A tensor with the filename used to save.
     """
     save = self.save_op(filename_tensor, saveables)
     return control_flow_ops.with_dependencies([save], filename_tensor)
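For reference, the V2 kernels that `save_op`/`restore_op` wrap can be driven directly. A sketch via `tf.raw_ops` (exposed in later 1.x releases; the file above reaches the same kernels through `io_ops`):

    import tensorflow as tf

    v = tf.get_variable("v", initializer=[1.0, 2.0])
    prefix = tf.constant("/tmp/ckpt/model")  # a checkpoint *prefix*, not a file

    save = tf.raw_ops.SaveV2(prefix=prefix, tensor_names=["v"],
                             shape_and_slices=[""], tensors=[v])
    restored = tf.raw_ops.RestoreV2(prefix=prefix, tensor_names=["v"],
                                    shape_and_slices=[""],
                                    dtypes=[tf.float32])[0]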
""" - # pylint: disable=protected-access - return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards) + # pylint: disable=protected-access + return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards) - def _AddSaveOps(self, filename_tensor, saveables): - """Add ops to save variables that are on the same shard. + def _AddSaveOps(self, filename_tensor, saveables): + """Add ops to save variables that are on the same shard. Args: filename_tensor: String Tensor. @@ -273,11 +273,11 @@ class BaseSaverBuilder(object): Returns: A tensor with the filename used to save. """ - save = self.save_op(filename_tensor, saveables) - return control_flow_ops.with_dependencies([save], filename_tensor) + save = self.save_op(filename_tensor, saveables) + return control_flow_ops.with_dependencies([save], filename_tensor) - def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device): - """Add ops to save the params per shard, for the V2 format. + def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device): + """Add ops to save the params per shard, for the V2 format. Note that the sharded save procedure for the V2 format is different from V1: there is a special "merge" step that merges the small metadata produced @@ -293,61 +293,67 @@ class BaseSaverBuilder(object): An op to save the variables, which, when evaluated, returns the prefix "" only and does not include the sharded spec suffix. """ - # IMPLEMENTATION DETAILS: most clients should skip. - # - # Suffix for any well-formed "checkpoint_prefix", when sharded. - # Transformations: - # * Users pass in "save_path" in save() and restore(). Say "myckpt". - # * checkpoint_prefix gets fed <_SHARDED_SUFFIX>. - # - # Example: - # During runtime, a temporary directory is first created, which contains - # files - # - # /myckpt_temp/ - # part-?????-of-?????{.index, .data-00000-of-00001} - # - # Before .save() finishes, they will be (hopefully, atomically) renamed to - # - # / - # myckpt{.index, .data-?????-of-?????} - # - # Users only need to interact with the user-specified prefix, which is - # "/myckpt" in this case. Save() and Restore() work with the - # prefix directly, instead of any physical pathname. (On failure and - # subsequent restore, an outdated and orphaned temporary directory can be - # safely removed.) - _SHARDED_SUFFIX = os.path.normpath("_temp_%s/part" % uuid.uuid4().hex) - tmp_checkpoint_prefix = string_ops.string_join( - [checkpoint_prefix, _SHARDED_SUFFIX]) + # IMPLEMENTATION DETAILS: most clients should skip. + # + # Suffix for any well-formed "checkpoint_prefix", when sharded. + # Transformations: + # * Users pass in "save_path" in save() and restore(). Say "myckpt". + # * checkpoint_prefix gets fed <_SHARDED_SUFFIX>. + # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it + # * Otherwise _temp/part is appended which is normalized relative to the OS + # + # Example: + # During runtime, a temporary directory is first created, which contains + # files + # + # /myckpt_temp/ + # part-?????-of-?????{.index, .data-00000-of-00001} + # + # Before .save() finishes, they will be (hopefully, atomically) renamed to + # + # / + # myckpt{.index, .data-?????-of-?????} + # + # Users only need to interact with the user-specified prefix, which is + # "/myckpt" in this case. Save() and Restore() work with the + # prefix directly, instead of any physical pathname. (On failure and + # subsequent restore, an outdated and orphaned temporary directory can be + # safely removed.) 
+ with ops.device("CPU"): + _SHARDED_SUFFIX = array_ops.where( + string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"), + constant_op.constant(".part"), + constant_op.constant(os.path.normpath("_temp/part"))) + tmp_checkpoint_prefix = string_ops.string_join( + [checkpoint_prefix, _SHARDED_SUFFIX]) - num_shards = len(per_device) - sharded_saves = [] - sharded_prefixes = [] - num_shards_tensor = constant_op.constant(num_shards, name="num_shards") - last_device = None - for shard, (device, saveables) in enumerate(per_device): - last_device = device - with ops.device(device): - sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, - num_shards_tensor) - sharded_prefixes.append(sharded_filename) - sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) + num_shards = len(per_device) + sharded_saves = [] + sharded_prefixes = [] + num_shards_tensor = constant_op.constant(num_shards, name="num_shards") + last_device = None + for shard, (device, saveables) in enumerate(per_device): + last_device = device + with ops.device(device): + sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, + num_shards_tensor) + sharded_prefixes.append(sharded_filename) + sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) - with ops.control_dependencies([x.op for x in sharded_saves]): - # Co-locates the merge step with the last device. - with ops.device(last_device): - # V2 format write path consists of a metadata merge step. Once merged, - # attempts to delete the temporary directory, "_temp". - merge_step = gen_io_ops.merge_v2_checkpoints( - sharded_prefixes, checkpoint_prefix, delete_old_dirs=True) - with ops.control_dependencies([merge_step]): - # Returns the prefix "" only. DOES NOT include the - # sharded spec suffix. - return array_ops.identity(checkpoint_prefix) + with ops.control_dependencies([x.op for x in sharded_saves]): + # Co-locates the merge step with the last device. + with ops.device(last_device): + # V2 format write path consists of a metadata merge step. Once merged, + # attempts to delete the temporary directory, "_temp". + merge_step = gen_io_ops.merge_v2_checkpoints( + sharded_prefixes, checkpoint_prefix, delete_old_dirs=True) + with ops.control_dependencies([merge_step]): + # Returns the prefix "" only. DOES NOT include the + # sharded spec suffix. + return array_ops.identity(checkpoint_prefix) - def _AddShardedSaveOps(self, filename_tensor, per_device): - """Add ops to save the params per shard. + def _AddShardedSaveOps(self, filename_tensor, per_device): + """Add ops to save the params per shard. Args: filename_tensor: a scalar String Tensor. @@ -357,30 +363,30 @@ class BaseSaverBuilder(object): Returns: An op to save the variables. """ - if self._write_version == saver_pb2.SaverDef.V2: - return self._AddShardedSaveOpsForV2(filename_tensor, per_device) + if self._write_version == saver_pb2.SaverDef.V2: + return self._AddShardedSaveOpsForV2(filename_tensor, per_device) - num_shards = len(per_device) - sharded_saves = [] - num_shards_tensor = constant_op.constant(num_shards, name="num_shards") - for shard, (device, saveables) in enumerate(per_device): - with ops.device(device): - sharded_filename = self.sharded_filename(filename_tensor, shard, - num_shards_tensor) - sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) - # Return the sharded name for the save path. 
 
   def _AddShardedSaveOps(self, filename_tensor, per_device):
     """Add ops to save the params per shard.
 
     Args:
       filename_tensor: a scalar String Tensor.
@@ -357,30 +363,30 @@ class BaseSaverBuilder(object):
     Returns:
       An op to save the variables.
     """
     if self._write_version == saver_pb2.SaverDef.V2:
       return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
 
     num_shards = len(per_device)
     sharded_saves = []
     num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
     for shard, (device, saveables) in enumerate(per_device):
       with ops.device(device):
         sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                  num_shards_tensor)
         sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
     # Return the sharded name for the save path.
     with ops.control_dependencies([x.op for x in sharded_saves]):
       # pylint: disable=protected-access
       return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor)
 
   def _AddRestoreOps(self,
                      filename_tensor,
                      saveables,
                      restore_sequentially,
                      reshape,
                      preferred_shard=-1,
                      name="restore_all"):
     """Add operations to restore saveables.
 
     Args:
       filename_tensor: Tensor for the path of the file to load.
@@ -395,35 +401,35 @@ class BaseSaverBuilder(object):
     Returns:
       An Operation that restores the variables.
     """
     assign_ops = []
     for saveable in saveables:
       restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
       # Load and optionally reshape on the CPU, as string tensors are not
       # available on the GPU.
       # TODO(touts): Re-enable restore on GPU when we can support annotating
       # string tensors as "HostMemory" inputs.
       with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
         with ops.control_dependencies(restore_control_inputs):
           tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
           shapes = None
           if reshape:
             # Compute the shapes, let the restore op decide if and how to do
             # the reshape.
             shapes = []
             for spec in saveable.specs:
               v = spec.tensor
               shape = v.get_shape()
               if not shape.is_fully_defined():
                 shape = array_ops.shape(v)
               shapes.append(shape)
           assign_ops.append(saveable.restore(tensors, shapes))
 
     # Create a Noop that has control dependencies from all the updates.
     return control_flow_ops.group(*assign_ops, name=name)
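Note that restore ops are always pinned to CPU (string tensors are unavailable on GPU), and `restore_sequentially` threads each restore behind the previous one via control dependencies. At the user level the same knobs surface on the Saver constructor (TF 1.x):

    import tensorflow as tf

    v = tf.get_variable("v", shape=[2])
    saver = tf.train.Saver(
        restore_sequentially=True,  # chain restores via control dependencies
        reshape=True)               # let restore reshape into the variable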
 
   def _AddShardedRestoreOps(self, filename_tensor, per_device,
                             restore_sequentially, reshape):
     """Add Ops to restore variables from multiple devices.
 
     Args:
       filename_tensor: Tensor for the path of the file to load.
@@ -437,25 +443,25 @@ class BaseSaverBuilder(object):
     Returns:
       An Operation that restores the variables.
     """
     sharded_restores = []
     for shard, (device, saveables) in enumerate(per_device):
       with ops.device(device):
         sharded_restores.append(
             self._AddRestoreOps(
                 filename_tensor,
                 saveables,
                 restore_sequentially,
                 reshape,
                 preferred_shard=shard,
                 name="restore_shard"))
     return control_flow_ops.group(*sharded_restores, name="restore_all")
 
   @staticmethod
   def _IsVariable(v):
     return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
 
   def _GroupByDevices(self, saveables):
     """Group Variable tensor slices per device.
 
     TODO(touts): Make sure that all the devices found are on different
     job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
@@ -471,19 +477,19 @@ class BaseSaverBuilder(object):
     Raises:
       ValueError: If the tensors of a saveable are on different devices.
     """
     per_device = collections.defaultdict(lambda: [])
     for saveable in saveables:
       canonical_device = set(
           pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
       if len(canonical_device) != 1:
         raise ValueError("All tensors of a saveable object must be "
                          "on the same device: %s" % saveable.name)
       per_device[canonical_device.pop()].append(saveable)
     return sorted(per_device.items(), key=lambda t: t[0])
""" - if not isinstance(op_list, (list, tuple, set)): - raise TypeError("Variables to save should be passed in a dict or a " - "list: %s" % op_list) - op_list = set(op_list) - names_to_saveables = {} - # pylint: disable=protected-access - for var in op_list: - if isinstance(var, BaseSaverBuilder.SaveableObject): - names_to_saveables[var.name] = var - elif isinstance(var, variables.PartitionedVariable): - if var.name in names_to_saveables: - raise ValueError("At least two variables have the same name: %s" % - var.name) - names_to_saveables[var.name] = var - elif ((isinstance(var, variables.Variable) or - isinstance(var, resource_variable_ops.ResourceVariable)) and - var._save_slice_info): - name = var._save_slice_info.full_name - if name in names_to_saveables: - if not isinstance(names_to_saveables[name], list): - raise ValueError("Mixing slices and non-slices with the same name: " - "%s" % name) - names_to_saveables[name].append(var) - else: - names_to_saveables[name] = [var] - else: - var = ops.internal_convert_to_tensor(var, as_ref=True) - if not BaseSaverBuilder._IsVariable(var): - raise TypeError("Variable to save is not a Variable: %s" % var) - if var.op.type == "ReadVariableOp": - name = var.op.inputs[0].op.name - else: - name = var.op.name - if name in names_to_saveables: - raise ValueError("At least two variables have the same name: %s" % - name) - names_to_saveables[name] = var - # pylint: enable=protected-access - return names_to_saveables + if not isinstance(op_list, (list, tuple, set)): + raise TypeError("Variables to save should be passed in a dict or a " + "list: %s" % op_list) + op_list = set(op_list) + names_to_saveables = {} + # pylint: disable=protected-access + for var in op_list: + if isinstance(var, BaseSaverBuilder.SaveableObject): + names_to_saveables[var.name] = var + elif isinstance(var, variables.PartitionedVariable): + if var.name in names_to_saveables: + raise ValueError("At least two variables have the same name: %s" % + var.name) + names_to_saveables[var.name] = var + elif ((isinstance(var, variables.Variable) or + isinstance(var, resource_variable_ops.ResourceVariable)) and + var._save_slice_info): + name = var._save_slice_info.full_name + if name in names_to_saveables: + if not isinstance(names_to_saveables[name], list): + raise ValueError("Mixing slices and non-slices with the same name: " + "%s" % name) + names_to_saveables[name].append(var) + else: + names_to_saveables[name] = [var] + else: + var = ops.internal_convert_to_tensor(var, as_ref=True) + if not BaseSaverBuilder._IsVariable(var): + raise TypeError("Variable to save is not a Variable: %s" % var) + if var.op.type == "ReadVariableOp": + name = var.op.inputs[0].op.name + else: + name = var.op.name + if name in names_to_saveables: + raise ValueError("At least two variables have the same name: %s" % + name) + names_to_saveables[name] = var + # pylint: enable=protected-access + return names_to_saveables - def _ValidateAndSliceInputs(self, names_to_saveables): - """Returns the variables and names that will be used for a Saver. + def _ValidateAndSliceInputs(self, names_to_saveables): + """Returns the variables and names that will be used for a Saver. Args: names_to_saveables: A dict (k, v) where k is the name of an operation and @@ -553,63 +559,63 @@ class BaseSaverBuilder(object): ValueError: If the same operation is given in more than one value (this also applies to slices of SlicedVariables). 
""" - if not isinstance(names_to_saveables, dict): - names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables) + if not isinstance(names_to_saveables, dict): + names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables) - saveables = [] - seen_ops = set() - for name in sorted(names_to_saveables.keys()): - if not isinstance(name, six.string_types): - raise TypeError( - "names_to_saveables must be a dict mapping string names to " - "checkpointable operations. Name is not a string: %s" % name) - op = names_to_saveables[name] - if isinstance(op, BaseSaverBuilder.SaveableObject): - self._AddSaveable(saveables, seen_ops, op) - elif isinstance(op, (list, tuple, variables.PartitionedVariable)): - if isinstance(op, variables.PartitionedVariable): - op = list(op) - # A set of slices. - slice_name = None - # pylint: disable=protected-access - for variable in op: - if (not isinstance(variable, variables.Variable) and - not isinstance(variable, resource_variable_ops.ResourceVariable)): - raise ValueError("Slices must all be Variables: %s" % variable) - if not variable._save_slice_info: - raise ValueError("Slices must all be slices: %s" % variable) - if slice_name is None: - slice_name = variable._save_slice_info.full_name - elif slice_name != variable._save_slice_info.full_name: - raise ValueError( - "Slices must all be from the same tensor: %s != %s" % - (slice_name, variable._save_slice_info.full_name)) - if variable.op.type in ["Variable", "VariableV2", - "AutoReloadVariable"]: - saveable = BaseSaverBuilder.VariableSaveable( - variable, variable._save_slice_info.spec, name) - else: - saveable = BaseSaverBuilder.ResourceVariableSaveable( - variable, variable._save_slice_info.spec, name) - self._AddSaveable(saveables, seen_ops, saveable) - # pylint: enable=protected-access - else: - # A variable or tensor. - variable = ops.internal_convert_to_tensor(op, as_ref=True) - if not BaseSaverBuilder._IsVariable(variable): - raise TypeError("names_to_saveables must be a dict mapping string " - "names to Tensors/Variables. Not a variable: %s" % - variable) - if variable.op.type in ["Variable", "VariableV2", "AutoReloadVariable"]: - saveable = BaseSaverBuilder.VariableSaveable(variable, "", name) - else: - saveable = BaseSaverBuilder.ResourceVariableSaveable( - variable, "", name) - self._AddSaveable(saveables, seen_ops, saveable) - return saveables + saveables = [] + seen_ops = set() + for name in sorted(names_to_saveables.keys()): + if not isinstance(name, six.string_types): + raise TypeError( + "names_to_saveables must be a dict mapping string names to " + "checkpointable operations. Name is not a string: %s" % name) + op = names_to_saveables[name] + if isinstance(op, BaseSaverBuilder.SaveableObject): + self._AddSaveable(saveables, seen_ops, op) + elif isinstance(op, (list, tuple, variables.PartitionedVariable)): + if isinstance(op, variables.PartitionedVariable): + op = list(op) + # A set of slices. 
         slice_name = None
         # pylint: disable=protected-access
         for variable in op:
           if (not isinstance(variable, variables.Variable) and
               not isinstance(variable, resource_variable_ops.ResourceVariable)):
             raise ValueError("Slices must all be Variables: %s" % variable)
           if not variable._save_slice_info:
             raise ValueError("Slices must all be slices: %s" % variable)
           if slice_name is None:
             slice_name = variable._save_slice_info.full_name
           elif slice_name != variable._save_slice_info.full_name:
             raise ValueError(
                 "Slices must all be from the same tensor: %s != %s" %
                 (slice_name, variable._save_slice_info.full_name))
           if variable.op.type in ["Variable", "VariableV2",
                                   "AutoReloadVariable"]:
             saveable = BaseSaverBuilder.VariableSaveable(
                 variable, variable._save_slice_info.spec, name)
           else:
             saveable = BaseSaverBuilder.ResourceVariableSaveable(
                 variable, variable._save_slice_info.spec, name)
           self._AddSaveable(saveables, seen_ops, saveable)
         # pylint: enable=protected-access
       else:
         # A variable or tensor.
         variable = ops.internal_convert_to_tensor(op, as_ref=True)
         if not BaseSaverBuilder._IsVariable(variable):
           raise TypeError("names_to_saveables must be a dict mapping string "
                           "names to Tensors/Variables. Not a variable: %s" %
                           variable)
         if variable.op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
           saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
         else:
           saveable = BaseSaverBuilder.ResourceVariableSaveable(
               variable, "", name)
         self._AddSaveable(saveables, seen_ops, saveable)
     return saveables
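The slice branch above is what makes partitioned variables checkpoint as one logical tensor: each partition carries `_save_slice_info` naming the full tensor plus its slice spec. Sketch (TF 1.x):

    import tensorflow as tf

    w = tf.get_variable(
        "w", shape=[10, 4],
        partitioner=tf.fixed_size_partitioner(num_shards=2))
    # Both partitions are saved under the single checkpoint key "w".
    saver = tf.train.Saver()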
 
   def _AddSaveable(self, saveables, seen_ops, saveable):
     """Adds the saveable to the saveables list.
 
     Args:
       saveables: List to append the SaveableObject to.
@@ -620,22 +626,22 @@ class BaseSaverBuilder(object):
     Raises:
       ValueError: If the saveable has already been processed.
     """
     if saveable.op in seen_ops:
       raise ValueError("The same saveable will be restored with two names: %s" %
                        saveable.name)
     saveables.append(saveable)
     seen_ops.add(saveable.op)
 
   def build(self,
             names_to_saveables,
             reshape=False,
             sharded=False,
             max_to_keep=5,
             keep_checkpoint_every_n_hours=10000.0,
             name=None,
             restore_sequentially=False,
             filename="model"):
     """Adds save/restore nodes to the graph and creates a SaverDef proto.
 
     Args:
       names_to_saveables: A dictionary mapping name to a Variable or
@@ -670,50 +676,50 @@ class BaseSaverBuilder(object):
       ValueError: If any of the keys or values in 'names_to_saveables' is not
         unique.
     """
     saveables = self._ValidateAndSliceInputs(names_to_saveables)
     if max_to_keep is None:
       max_to_keep = 0
 
     with ops.name_scope(name, "save",
                         [saveable.op for saveable in saveables]) as name:
       # Add the Constant string tensor for the filename.
       filename_tensor = constant_op.constant(filename)
 
       # Add the save ops.
       if sharded:
         per_device = self._GroupByDevices(saveables)
         save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
         restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                 restore_sequentially, reshape)
       else:
         save_tensor = self._AddSaveOps(filename_tensor, saveables)
         restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                          restore_sequentially, reshape)
 
       # In the following use case, it's possible to have restore_ops be called
       # something else:
       # - Build inference graph and export a meta_graph.
       # - Import the inference meta_graph
       # - Extend the inference graph to a train graph.
       # - Export a new meta_graph.
       # Now the second restore_op will be called "restore_all_1".
       # As such, comment out the assert for now until we know whether supporting
       # such usage model makes sense.
       #
       # assert restore_op.name.endswith("restore_all"), restore_op.name
 
       return saver_pb2.SaverDef(
           filename_tensor_name=filename_tensor.name,
           save_tensor_name=save_tensor.name,
           restore_op_name=restore_op.name,
           max_to_keep=max_to_keep,
           sharded=sharded,
           keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
           version=self._write_version)
 
 
 def _GetCheckpointFilename(save_dir, latest_filename):
   """Returns a filename for storing the CheckpointState.
 
   Args:
     save_dir: The directory for saving and restoring checkpoints.
@@ -723,15 +729,15 @@ def _GetCheckpointFilename(save_dir, latest_filename):
   Returns:
     The path of the file that contains the CheckpointState proto.
   """
   if latest_filename is None:
     latest_filename = "checkpoint"
   return os.path.join(save_dir, latest_filename)
 
 
 def generate_checkpoint_state_proto(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None):
   """Generates a checkpoint state proto.
 
   Args:
     save_dir: Directory where the model was saved.
@@ -746,37 +752,37 @@ def generate_checkpoint_state_proto(save_dir,
     all_model_checkpoint_paths updated to either absolute paths or
     relative paths to the current save_dir.
   """
   if all_model_checkpoint_paths is None:
     all_model_checkpoint_paths = []
 
   if (not all_model_checkpoint_paths or
       all_model_checkpoint_paths[-1] != model_checkpoint_path):
     logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                  model_checkpoint_path)
     all_model_checkpoint_paths.append(model_checkpoint_path)
 
   # Relative paths need to be rewritten to be relative to the "save_dir"
   # if model_checkpoint_path already contains "save_dir".
   if not os.path.isabs(save_dir):
     if not os.path.isabs(model_checkpoint_path):
       model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
     for i in range(len(all_model_checkpoint_paths)):
       p = all_model_checkpoint_paths[i]
       if not os.path.isabs(p):
         all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
 
   coord_checkpoint_proto = CheckpointState(
       model_checkpoint_path=model_checkpoint_path,
       all_model_checkpoint_paths=all_model_checkpoint_paths)
 
   return coord_checkpoint_proto
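The proto built here is what ends up, in text format, inside the `checkpoint` file next to the data files — for example, after two saves of `model.ckpt` (illustrative paths):

    model_checkpoint_path: "model.ckpt-2000"
    all_model_checkpoint_paths: "model.ckpt-1000"
    all_model_checkpoint_paths: "model.ckpt-2000"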
 
 
 def update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None):
   """Updates the content of the 'checkpoint' file.
 
   This updates the checkpoint file containing a CheckpointState
   proto.
 
@@ -794,26 +800,26 @@ def update_checkpoint_state(save_dir,
   Raises:
     RuntimeError: If the save paths conflict.
   """
   # Writes the "checkpoint" file for the coordinator for later restoration.
   coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
   ckpt = generate_checkpoint_state_proto(
       save_dir,
       model_checkpoint_path,
       all_model_checkpoint_paths=all_model_checkpoint_paths)
 
   if coord_checkpoint_filename == ckpt.model_checkpoint_path:
     raise RuntimeError("Save path '%s' conflicts with path used for "
                        "checkpoint state. Please use a different save path." %
                        model_checkpoint_path)
 
   # Preventing potential read/write race condition by *atomically* writing to a
   # file.
   file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                       text_format.MessageToString(ckpt))
 
 
 def get_checkpoint_state(checkpoint_dir, latest_filename=None):
   """Returns CheckpointState proto from the "checkpoint" file.
 
   If the "checkpoint" file contains a valid CheckpointState
   proto, returns it.
 
@@ -830,47 +836,47 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
   Raises:
     ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
   """
   ckpt = None
   coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                      latest_filename)
   f = None
   try:
     # Check that the file exists before opening it to avoid
     # many lines of errors from colossus in the logs.
     if file_io.file_exists(coord_checkpoint_filename):
       file_content = file_io.read_file_to_string(
           coord_checkpoint_filename)
       ckpt = CheckpointState()
       text_format.Merge(file_content, ckpt)
       if not ckpt.model_checkpoint_path:
         raise ValueError("Invalid checkpoint state loaded from %s",
                          checkpoint_dir)
       # For relative model_checkpoint_path and all_model_checkpoint_paths,
       # prepend checkpoint_dir.
       if not os.path.isabs(ckpt.model_checkpoint_path):
         ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                   ckpt.model_checkpoint_path)
       for i in range(len(ckpt.all_model_checkpoint_paths)):
         p = ckpt.all_model_checkpoint_paths[i]
         if not os.path.isabs(p):
           ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
   except errors.OpError as e:
     # It's ok if the file cannot be read
     logging.warning(str(e))
     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
     return None
   except text_format.ParseError as e:
     logging.warning(str(e))
     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
     return None
   finally:
     if f:
       f.close()
   return ckpt
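Usage sketch for the reader side (TF 1.x public API; directory is illustrative):

    import tensorflow as tf

    ckpt = tf.train.get_checkpoint_state("/tmp/train")
    if ckpt and ckpt.model_checkpoint_path:
      print(ckpt.model_checkpoint_path)             # newest checkpoint
      print(list(ckpt.all_model_checkpoint_paths))  # everything still kept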
" + "Either set defer_build=False or var_list=None.") + self._var_list = var_list + self._reshape = reshape + self._sharded = sharded + self._max_to_keep = max_to_keep + self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours + self._name = name + self._restore_sequentially = restore_sequentially + self.saver_def = saver_def + self._builder = builder + self._is_built = False + self._allow_empty = allow_empty + self._is_empty = None + self._write_version = write_version + self._pad_step_number = pad_step_number + if not defer_build: + self.build() + if self.saver_def: + self._check_saver_def() + self._write_version = self.saver_def.version - def build(self): - """Builds saver_def.""" - if self._is_built: - return - self._is_built = True - if not self.saver_def: - if self._builder is None: - self._builder = BaseSaverBuilder(self._write_version) - if self._var_list is None: - # pylint: disable=protected-access - self._var_list = variables._all_saveable_objects() - if not self._var_list: - if self._allow_empty: - self._is_empty = True - return - else: - raise ValueError("No variables to save") - self._is_empty = False - self.saver_def = self._builder.build( - self._var_list, - reshape=self._reshape, - sharded=self._sharded, - max_to_keep=self._max_to_keep, - keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours, - name=self._name, - restore_sequentially=self._restore_sequentially) - elif self.saver_def and self._name: - # Since self._name is used as a name_scope by builder(), we are - # overloading the use of this field to represent the "import_scope" as - # well. - self.saver_def.filename_tensor_name = ops.prepend_name_scope( - self.saver_def.filename_tensor_name, self._name) - self.saver_def.save_tensor_name = ops.prepend_name_scope( - self.saver_def.save_tensor_name, self._name) - self.saver_def.restore_op_name = ops.prepend_name_scope( - self.saver_def.restore_op_name, self._name) + def build(self): + """Builds saver_def.""" + if self._is_built: + return + self._is_built = True + if not self.saver_def: + if self._builder is None: + self._builder = BaseSaverBuilder(self._write_version) + if self._var_list is None: + # pylint: disable=protected-access + self._var_list = variables._all_saveable_objects() + if not self._var_list: + if self._allow_empty: + self._is_empty = True + return + else: + raise ValueError("No variables to save") + self._is_empty = False + self.saver_def = self._builder.build( + self._var_list, + reshape=self._reshape, + sharded=self._sharded, + max_to_keep=self._max_to_keep, + keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours, + name=self._name, + restore_sequentially=self._restore_sequentially) + elif self.saver_def and self._name: + # Since self._name is used as a name_scope by builder(), we are + # overloading the use of this field to represent the "import_scope" as + # well. + self.saver_def.filename_tensor_name = ops.prepend_name_scope( + self.saver_def.filename_tensor_name, self._name) + self.saver_def.save_tensor_name = ops.prepend_name_scope( + self.saver_def.save_tensor_name, self._name) + self.saver_def.restore_op_name = ops.prepend_name_scope( + self.saver_def.restore_op_name, self._name) - self._check_saver_def() - # Updates next checkpoint time. - self._next_checkpoint_time = ( - time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) - self._last_checkpoints = [] + self._check_saver_def() + # Updates next checkpoint time. 
     self._next_checkpoint_time = (
         time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
     self._last_checkpoints = []
 
   def _check_saver_def(self):
     if not isinstance(self.saver_def, saver_pb2.SaverDef):
       raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                        self.saver_def)
     if not self.saver_def.save_tensor_name:
       raise ValueError("saver_def must specify the save_tensor_name: %s" %
                        str(self.saver_def))
     if not self.saver_def.restore_op_name:
       raise ValueError("saver_def must specify the restore_op_name: %s" %
                        str(self.saver_def))
 
   def _CheckpointFilename(self, p):
     """Returns the checkpoint filename given a `(filename, time)` pair.
 
     Args:
       p: (filename, time) pair.
@@ -1121,11 +1127,11 @@ class Saver(object):
     Returns:
       Checkpoint file name.
     """
     name, _ = p
     return name
 
   def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
     """Returns the meta graph filename.
 
     Args:
       checkpoint_filename: Name of the checkpoint file.
@@ -1134,17 +1140,17 @@ class Saver(object):
     Returns:
       MetaGraph file name.
     """
     # If the checkpoint_filename is sharded, the checkpoint_filename could
     # be of format model.ckpt-step#-?????-of-shard#. For example,
     # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
     basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
     meta_graph_filename = ".".join([basename, meta_graph_suffix])
     return meta_graph_filename
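The regex above strips a sharded suffix before attaching the meta suffix; a quick standalone check of the behaviour:

    import re

    def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
      # Drop a trailing "-NNNNN-of-NNNNN" (digits or "?" placeholders).
      basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
      return ".".join([basename, meta_graph_suffix])

    print(meta_graph_filename("model.ckpt-123456-?????-of-00005"))
    # -> model.ckpt-123456.meta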
 
   def _MaybeDeleteOldCheckpoints(self,
                                  latest_save_path,
                                  meta_graph_suffix="meta"):
     """Deletes old checkpoints if necessary.
 
     Always keep the last `max_to_keep` checkpoints.  If
     `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
@@ -1156,55 +1162,55 @@ class Saver(object):
       latest_save_path: Name including path of checkpoint file to save.
       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
     """
     if not self.saver_def.max_to_keep:
       return
     # Remove first from list if the same name was used before.
     for p in self._last_checkpoints:
       if latest_save_path == self._CheckpointFilename(p):
         self._last_checkpoints.remove(p)
     # Append new path to list
     self._last_checkpoints.append((latest_save_path, time.time()))
     # If more than max_to_keep, remove oldest.
     if len(self._last_checkpoints) > self.saver_def.max_to_keep:
       p = self._last_checkpoints.pop(0)
       # Do not delete the file if we keep_checkpoint_every_n_hours is set and we
       # have reached N hours of training.
       should_keep = p[1] > self._next_checkpoint_time
       if should_keep:
         self._next_checkpoint_time += (
             self.saver_def.keep_checkpoint_every_n_hours * 3600)
         return
 
       # Otherwise delete the files.
       try:
         checkpoint_prefix = self._CheckpointFilename(p)
         self._delete_file_if_exists(
             self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
         if self.saver_def.version == saver_pb2.SaverDef.V2:
           # V2 has a metadata file and some data files.
           self._delete_file_if_exists(checkpoint_prefix + ".index")
           self._delete_file_if_exists(checkpoint_prefix +
                                       ".data-?????-of-?????")
         else:
           # V1, Legacy.  Exact match on the data file.
           self._delete_file_if_exists(checkpoint_prefix)
       except Exception as e:  # pylint: disable=broad-except
         logging.warning("Ignoring: %s", str(e))
 
   def _delete_file_if_exists(self, filespec):
     for pathname in file_io.get_matching_files(filespec):
       file_io.delete_file(pathname)
 
   def as_saver_def(self):
     """Generates a `SaverDef` representation of this saver.
 
     Returns:
       A `SaverDef` proto.
     """
     return self.saver_def
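The garbage-collection policy implemented above is driven entirely by two constructor arguments; V2 checkpoints are removed by glob because each one is an .index file plus sharded .data files. Usage sketch (TF 1.x):

    import tensorflow as tf

    v = tf.get_variable("v", shape=[2])
    saver = tf.train.Saver(
        max_to_keep=5,                      # keep the five newest checkpoints
        keep_checkpoint_every_n_hours=2.0)  # plus one archival copy every 2h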
""" - if (export_scope is None or - self._name.startswith(export_scope)): - return self.saver_def - else: - return None + if (export_scope is None or + self._name.startswith(export_scope)): + return self.saver_def + else: + return None - @staticmethod - def from_proto(saver_def, import_scope=None): - """Returns a `Saver` object created from `saver_def`. + @staticmethod + def from_proto(saver_def, import_scope=None): + """Returns a `Saver` object created from `saver_def`. Args: saver_def: a `SaveDef` protocol buffer. @@ -1229,21 +1235,21 @@ class Saver(object): Returns: A `Saver` built from saver_def. """ - return Saver(saver_def=saver_def, name=import_scope) + return Saver(saver_def=saver_def, name=import_scope) - @property - def last_checkpoints(self): - """List of not-yet-deleted checkpoint filenames. + @property + def last_checkpoints(self): + """List of not-yet-deleted checkpoint filenames. You can pass any of the returned values to `restore()`. Returns: A list of checkpoint filenames, sorted from oldest to newest. """ - return list(self._CheckpointFilename(p) for p in self._last_checkpoints) + return list(self._CheckpointFilename(p) for p in self._last_checkpoints) - def set_last_checkpoints(self, last_checkpoints): - """DEPRECATED: Use set_last_checkpoints_with_time. + def set_last_checkpoints(self, last_checkpoints): + """DEPRECATED: Use set_last_checkpoints_with_time. Sets the list of old checkpoint filenames. @@ -1253,14 +1259,14 @@ class Saver(object): Raises: AssertionError: If last_checkpoints is not a list. """ - assert isinstance(last_checkpoints, list) - # We use a timestamp of +inf so that this checkpoint will never be - # deleted. This is both safe and backwards compatible to a previous - # version of the code which used s[1] as the "timestamp". - self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] + assert isinstance(last_checkpoints, list) + # We use a timestamp of +inf so that this checkpoint will never be + # deleted. This is both safe and backwards compatible to a previous + # version of the code which used s[1] as the "timestamp". + self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] - def set_last_checkpoints_with_time(self, last_checkpoints_with_time): - """Sets the list of old checkpoint filenames and timestamps. + def set_last_checkpoints_with_time(self, last_checkpoints_with_time): + """Sets the list of old checkpoint filenames and timestamps. Args: last_checkpoints_with_time: A list of tuples of checkpoint filenames and @@ -1269,11 +1275,11 @@ class Saver(object): Raises: AssertionError: If last_checkpoints_with_time is not a list. """ - assert isinstance(last_checkpoints_with_time, list) - self._last_checkpoints = last_checkpoints_with_time + assert isinstance(last_checkpoints_with_time, list) + self._last_checkpoints = last_checkpoints_with_time - def recover_last_checkpoints(self, checkpoint_paths): - """Recovers the internal saver state after a crash. + def recover_last_checkpoints(self, checkpoint_paths): + """Recovers the internal saver state after a crash. This method is useful for recovering the "self._last_checkpoints" state. @@ -1283,18 +1289,18 @@ class Saver(object): Args: checkpoint_paths: a list of checkpoint paths. 
""" - mtimes = get_checkpoint_mtimes(checkpoint_paths) - self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) + mtimes = get_checkpoint_mtimes(checkpoint_paths) + self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) - def save(self, - sess, - save_path, - global_step=None, - latest_filename=None, - meta_graph_suffix="meta", - write_meta_graph=True, - write_state=True): - """Saves variables. + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True): + """Saves variables. This method runs the ops added by the constructor for saving variables. It requires a session in which the graph was launched. The variables to @@ -1333,75 +1339,75 @@ class Saver(object): collides with `save_path`. RuntimeError: If save and restore ops weren't built. """ - if not self._is_built: - raise RuntimeError( - "`build()` should be called before save if defer_build==True") - if latest_filename is None: - latest_filename = "checkpoint" - if self._write_version != saver_pb2.SaverDef.V2: - logging.warning("*******************************************************") - logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") - logging.warning("Consider switching to the more efficient V2 format:") - logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`") - logging.warning("now on by default.") - logging.warning("*******************************************************") + if not self._is_built: + raise RuntimeError( + "`build()` should be called before save if defer_build==True") + if latest_filename is None: + latest_filename = "checkpoint" + if self._write_version != saver_pb2.SaverDef.V2: + logging.warning("*******************************************************") + logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") + logging.warning("Consider switching to the more efficient V2 format:") + logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`") + logging.warning("now on by default.") + logging.warning("*******************************************************") - if os.path.split(latest_filename)[0]: - raise ValueError("'latest_filename' must not contain path components") + if os.path.split(latest_filename)[0]: + raise ValueError("'latest_filename' must not contain path components") - if global_step is not None: - if not isinstance(global_step, compat.integral_types): - global_step = training_util.global_step(sess, global_step) - checkpoint_file = "%s-%d" % (save_path, global_step) - if self._pad_step_number: - # Zero-pads the step numbers, so that they are sorted when listed. - checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) - else: - checkpoint_file = save_path - if os.path.basename( - save_path) == latest_filename and not self.saver_def.sharded: - # Guard against collision between data file and checkpoint state file. - raise ValueError( - "'latest_filename' collides with 'save_path': '%s' and '%s'" % - (latest_filename, save_path)) + if global_step is not None: + if not isinstance(global_step, compat.integral_types): + global_step = training_util.global_step(sess, global_step) + checkpoint_file = "%s-%d" % (save_path, global_step) + if self._pad_step_number: + # Zero-pads the step numbers, so that they are sorted when listed. 
+ checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) + else: + checkpoint_file = save_path + if os.path.basename( + save_path) == latest_filename and not self.saver_def.sharded: + # Guard against collision between data file and checkpoint state file. + raise ValueError( + "'latest_filename' collides with 'save_path': '%s' and '%s'" % + (latest_filename, save_path)) - if not gfile.IsDirectory(os.path.dirname(save_path)): - raise ValueError( - "Parent directory of {} doesn't exist, can't save.".format(save_path)) + if not gfile.IsDirectory(os.path.dirname(save_path)): + raise ValueError( + "Parent directory of {} doesn't exist, can't save.".format(save_path)) - save_path = os.path.dirname(save_path) - if not isinstance(sess, session.SessionInterface): - raise TypeError("'sess' must be a Session; %s" % sess) + save_path = os.path.dirname(save_path) + if not isinstance(sess, session.SessionInterface): + raise TypeError("'sess' must be a Session; %s" % sess) - if not self._is_empty: - model_checkpoint_path = sess.run( - self.saver_def.save_tensor_name, - {self.saver_def.filename_tensor_name: checkpoint_file}) - model_checkpoint_path = compat.as_str(model_checkpoint_path) - if write_state: - self._MaybeDeleteOldCheckpoints( - model_checkpoint_path, meta_graph_suffix=meta_graph_suffix) - update_checkpoint_state(save_path, model_checkpoint_path, - self.last_checkpoints, latest_filename) + if not self._is_empty: + model_checkpoint_path = sess.run( + self.saver_def.save_tensor_name, + {self.saver_def.filename_tensor_name: checkpoint_file}) + model_checkpoint_path = compat.as_str(model_checkpoint_path) + if write_state: + self._MaybeDeleteOldCheckpoints( + model_checkpoint_path, meta_graph_suffix=meta_graph_suffix) + update_checkpoint_state(save_path, model_checkpoint_path, + self.last_checkpoints, latest_filename) - if write_meta_graph: - meta_graph_filename = self._MetaGraphFilename( - checkpoint_file, meta_graph_suffix=meta_graph_suffix) - with sess.graph.as_default(): - self.export_meta_graph(meta_graph_filename) + if write_meta_graph: + meta_graph_filename = self._MetaGraphFilename( + checkpoint_file, meta_graph_suffix=meta_graph_suffix) + with sess.graph.as_default(): + self.export_meta_graph(meta_graph_filename) - if self._is_empty: - return None - else: - return model_checkpoint_path + if self._is_empty: + return None + else: + return model_checkpoint_path - def export_meta_graph(self, - filename=None, - collection_list=None, - as_text=False, - export_scope=None, - clear_devices=False): - """Writes `MetaGraphDef` to save_path/filename. + def export_meta_graph(self, + filename=None, + collection_list=None, + as_text=False, + export_scope=None, + clear_devices=False): + """Writes `MetaGraphDef` to save_path/filename. Args: filename: Optional meta_graph filename including the path. @@ -1414,17 +1420,17 @@ class Saver(object): Returns: A `MetaGraphDef` proto. """ - return export_meta_graph( - filename=filename, - graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), - saver_def=self.saver_def, - collection_list=collection_list, - as_text=as_text, - export_scope=export_scope, - clear_devices=clear_devices) + return export_meta_graph( + filename=filename, + graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), + saver_def=self.saver_def, + collection_list=collection_list, + as_text=as_text, + export_scope=export_scope, + clear_devices=clear_devices) - def restore(self, sess, save_path): - """Restores previously saved variables. 
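
# Concretely, the checkpoint names built in save() above (the save_path and
# step values are purely illustrative): "%s-%d" % ("/tmp/model", 1234) gives
# '/tmp/model-1234', while pad_step_number=True uses "{:08d}".format(1234)
# -> '00001234', so lexicographic listings and numeric step order agree.
# A typical round trip, assuming a live session `sess` and an integer `step`;
# the string save() returns is the checkpoint prefix restore() expects:
prefix = saver.save(sess, "/tmp/model", global_step=step)
# V2 format writes <prefix>.index and <prefix>.data-*, plus <prefix>.meta
# while write_meta_graph stays True.
saver.restore(sess, prefix)
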
+  def restore(self, sess, save_path):
+    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched. The variables to

@@ -1438,27 +1444,27 @@ class Saver(object):
      sess: A `Session` to use to restore the parameters.
      save_path: Path where parameters were previously saved.
    """
-    if self._is_empty:
-      return
-    logging.info("Restoring parameters from %s", save_path)
-    sess.run(self.saver_def.restore_op_name,
-             {self.saver_def.filename_tensor_name: save_path})
+    if self._is_empty:
+      return
+    logging.info("Restoring parameters from %s", save_path)
+    sess.run(self.saver_def.restore_op_name,
+             {self.saver_def.filename_tensor_name: save_path})

-  @staticmethod
-  def _add_collection_def(meta_graph_def, key, export_scope=None):
-    """Adds a collection to MetaGraphDef protocol buffer.
+  @staticmethod
+  def _add_collection_def(meta_graph_def, key, export_scope=None):
+    """Adds a collection to MetaGraphDef protocol buffer.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
-    meta_graph.add_collection_def(meta_graph_def, key,
-                                  export_scope=export_scope)
+    meta_graph.add_collection_def(meta_graph_def, key,
+                                  export_scope=export_scope)


 def _prefix_to_checkpoint_path(prefix, format_version):
-  """Returns the pathname of a checkpoint file, given the checkpoint prefix.
+  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
  returns the pathname to the index file.

@@ -1471,13 +1477,13 @@ def _prefix_to_checkpoint_path(prefix, format_version):
  The pathname of a checkpoint file, taking into account the checkpoint
  format version.
  """
-  if format_version == saver_pb2.SaverDef.V2:
-    return prefix + ".index"  # The index file identifies a checkpoint.
-  return prefix  # Just the data file.
+  if format_version == saver_pb2.SaverDef.V2:
+    return prefix + ".index"  # The index file identifies a checkpoint.
+  return prefix  # Just the data file.


 def latest_checkpoint(checkpoint_dir, latest_filename=None):
-  """Finds the filename of latest saved checkpoint file.
+  """Finds the filename of the latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.

@@ -1488,26 +1494,26 @@
  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
-  # Pick the latest checkpoint based on checkpoint state.
-  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
-  if ckpt and ckpt.model_checkpoint_path:
-    # Look for either a V2 path or a V1 path, with priority for V2.
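
# What _prefix_to_checkpoint_path resolves to for each format version, with
# a hypothetical prefix; the ".index" file identifies a V2 checkpoint, while
# a V1 checkpoint is identified by the data file itself:
v2 = _prefix_to_checkpoint_path("/tmp/model-1234", saver_pb2.SaverDef.V2)
v1 = _prefix_to_checkpoint_path("/tmp/model-1234", saver_pb2.SaverDef.V1)
assert v2 == "/tmp/model-1234.index"
assert v1 == "/tmp/model-1234"
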
+    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+                                         saver_pb2.SaverDef.V2)
+    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+                                         saver_pb2.SaverDef.V1)
+    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
+        v1_path):
+      return ckpt.model_checkpoint_path
+    else:
+      logging.error("Couldn't match files for checkpoint %s",
+                    ckpt.model_checkpoint_path)
+  return None


 def import_meta_graph(meta_graph_or_file, clear_devices=False,
                       import_scope=None, **kwargs):
-  """Recreates a Graph saved in a `MetaGraphDef` proto.
+  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,

@@ -1572,26 +1578,26 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
  A None value is returned if no variables exist in the `MetaGraphDef`
  (i.e., there are no variables to restore).
  """
-  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
-    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
-  else:
-    meta_graph_def = meta_graph_or_file
-
-  meta_graph.import_scoped_meta_graph(meta_graph_def,
-                                      clear_devices=clear_devices,
-                                      import_scope=import_scope,
-                                      **kwargs)
-  if meta_graph_def.HasField("saver_def"):
-    return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
-  else:
-    if variables._all_saveable_objects():  # pylint: disable=protected-access
-      # Return the default saver instance for all graph variables.
-      return Saver()
+  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
+    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  else:
-    # If no graph variables exist, then a Saver cannot be constructed.
-    logging.info("Saver not created because there are no variables in the"
-                 " graph to restore")
-    return None
+    meta_graph_def = meta_graph_or_file
+
+  meta_graph.import_scoped_meta_graph(meta_graph_def,
+                                      clear_devices=clear_devices,
+                                      import_scope=import_scope,
+                                      **kwargs)
+  if meta_graph_def.HasField("saver_def"):
+    return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
+  else:
+    if variables._all_saveable_objects():  # pylint: disable=protected-access
+      # Return the default saver instance for all graph variables.
+      return Saver()
+    else:
+      # If no graph variables exist, then a Saver cannot be constructed.
+      logging.info("Saver not created because there are no variables in the"
+                   " graph to restore")
+      return None


 def export_meta_graph(filename=None,

@@ -1604,7 +1610,7 @@ def export_meta_graph(filename=None,
                      export_scope=None,
                      clear_devices=False,
                      **kwargs):
-  """Returns `MetaGraphDef` proto. Optionally writes it to filename.
+  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported

@@ -1634,22 +1640,22 @@ def export_meta_graph(filename=None,
  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
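
# A common cold-start pattern combining the two functions above; the
# directory is illustrative and `sess` is assumed. latest_checkpoint()
# returns the prefix recorded in the checkpoint state file, and the matching
# ".meta" file rebuilds the graph and returns a Saver wired to it:
prefix = tf.train.latest_checkpoint("/tmp/train_dir")
if prefix:
  new_saver = tf.train.import_meta_graph(prefix + ".meta")
  new_saver.restore(sess, prefix)
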
""" - meta_graph_def, _ = meta_graph.export_scoped_meta_graph( - filename=filename, - meta_info_def=meta_info_def, - graph_def=graph_def, - saver_def=saver_def, - collection_list=collection_list, - as_text=as_text, - graph=graph, - export_scope=export_scope, - clear_devices=clear_devices, - **kwargs) - return meta_graph_def + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + filename=filename, + meta_info_def=meta_info_def, + graph_def=graph_def, + saver_def=saver_def, + collection_list=collection_list, + as_text=as_text, + graph=graph, + export_scope=export_scope, + clear_devices=clear_devices, + **kwargs) + return meta_graph_def def checkpoint_exists(checkpoint_prefix): - """Checks whether a V1 or V2 checkpoint exists with the specified prefix. + """Checks whether a V1 or V2 checkpoint exists with the specified prefix. This is the recommended way to check if a checkpoint exists, since it takes into account the naming difference between V1 and V2 formats. @@ -1662,18 +1668,18 @@ def checkpoint_exists(checkpoint_prefix): Returns: A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. """ - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if file_io.get_matching_files(pathname): - return True - elif file_io.get_matching_files(checkpoint_prefix): - return True - else: - return False + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if file_io.get_matching_files(pathname): + return True + elif file_io.get_matching_files(checkpoint_prefix): + return True + else: + return False def get_checkpoint_mtimes(checkpoint_prefixes): - """Returns the mtimes (modification timestamps) of the checkpoints. + """Returns the mtimes (modification timestamps) of the checkpoints. Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files exist, collect their mtime. Both V2 and V1 checkpoints are considered, in @@ -1689,25 +1695,25 @@ def get_checkpoint_mtimes(checkpoint_prefixes): Returns: A list of mtimes (in microseconds) of the found checkpoints. """ - mtimes = [] + mtimes = [] - def match_maybe_append(pathname): - fnames = file_io.get_matching_files(pathname) - if fnames: - mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) - return True - return False + def match_maybe_append(pathname): + fnames = file_io.get_matching_files(pathname) + if fnames: + mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) + return True + return False - for checkpoint_prefix in checkpoint_prefixes: - # Tries V2's metadata file first. - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if match_maybe_append(pathname): - continue - # Otherwise, tries V1, where the prefix is the complete pathname. - match_maybe_append(checkpoint_prefix) + for checkpoint_prefix in checkpoint_prefixes: + # Tries V2's metadata file first. + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if match_maybe_append(pathname): + continue + # Otherwise, tries V1, where the prefix is the complete pathname. + match_maybe_append(checkpoint_prefix) - return mtimes + return mtimes ops.register_proto_function(