From efe565bc0981e80a52a97f3961cfba3e87023b42 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 4 Jan 2019 14:08:26 -0800 Subject: [PATCH] Make the initial tf.train.Checkpoint.restore() read Tensors in a batch Should be faster when reading from distributed file systems. Does not affect cases where restore-on-create is necessary, but as long as variable objects have been created and tracked before restore() their reads should be batched together. PiperOrigin-RevId: 227911381 --- .../python/training/checkpointable/util.py | 22 +++------ .../training/saving/functional_saver.py | 45 +++++++++++-------- 2 files changed, 31 insertions(+), 36 deletions(-) diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index c890a7f4408..a45263f5c6b 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -176,25 +176,13 @@ class _CheckpointRestoreCoordinator(object): raise AssertionError( ("Saveable keys changed when validating. Got back %s, was " "expecting %s") % (tensor_saveables.keys(), validated_names)) - for saveable in validated_saveables: - if saveable.device: - device = saveable_object_util.set_cpu0(saveable.device) - else: - device = None - with ops.device(device): - tensors = [] - for spec in saveable.specs: - tensors.append( - io_ops.restore_v2( - self.save_path_tensor, - [spec.name], - [spec.slice_spec], - [spec.dtype])[0]) - restore_op = saveable.restore(tensors, restored_shapes=None) - if not context.executing_eagerly(): + new_restore_ops = functional_saver.restore_from_saveable_objects( + self.save_path_tensor, validated_saveables) + if not context.executing_eagerly(): + restore_ops.extend(new_restore_ops) + for saveable, restore_op in zip(validated_saveables, new_restore_ops): assert saveable.name not in self.restore_ops_by_name self.restore_ops_by_name[saveable.name] = restore_op - restore_ops.append(restore_op) return restore_ops diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py index 51f618ddd32..4ff2742c2f1 100644 --- a/tensorflow/python/training/saving/functional_saver.py +++ b/tensorflow/python/training/saving/functional_saver.py @@ -107,25 +107,32 @@ class Saver(object): A scalar string Tensor containing `file_prefix` with control dependencies on the restore ops. """ - restore_specs = [] - tensor_structure = [] - for saveable in self._saveable_objects: - saveable_tensor_structure = [] - tensor_structure.append(saveable_tensor_structure) - for spec in saveable.specs: - saveable_tensor_structure.append(spec.name) - restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) - tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) - with ops.device("cpu:0"): - restored_tensors = io_ops.restore_v2( - file_prefix, tensor_names, tensor_slices, tensor_dtypes) - structured_restored_tensors = nest.pack_sequence_as( - tensor_structure, restored_tensors) - restore_ops = [] - for saveable, restored_tensors in zip(self._saveable_objects, - structured_restored_tensors): - restore_ops.append(saveable.restore(restored_tensors, - restored_shapes=None)) + restore_ops = restore_from_saveable_objects( + file_prefix, self._saveable_objects) with ops.device("cpu:0"): with ops.control_dependencies(restore_ops): return array_ops.identity(file_prefix) + + +def restore_from_saveable_objects(file_prefix, saveable_objects): + """Reads from a checkpoint and returns restore ops for `saveable_objects`s.""" + restore_specs = [] + tensor_structure = [] + for saveable in saveable_objects: + saveable_tensor_structure = [] + tensor_structure.append(saveable_tensor_structure) + for spec in saveable.specs: + saveable_tensor_structure.append(spec.name) + restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) + tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) + with ops.device("cpu:0"): + restored_tensors = io_ops.restore_v2( + file_prefix, tensor_names, tensor_slices, tensor_dtypes) + structured_restored_tensors = nest.pack_sequence_as( + tensor_structure, restored_tensors) + restore_ops = [] + for saveable, restored_tensors in zip(saveable_objects, + structured_restored_tensors): + restore_ops.append(saveable.restore(restored_tensors, + restored_shapes=None)) + return restore_ops