Make the initial tf.train.Checkpoint.restore() read Tensors in a batch

Should be faster when reading from distributed file systems. Does not affect cases where restore-on-create is necessary, but as long as variable objects have been created and tracked before restore() is called, their reads should be batched together.

PiperOrigin-RevId: 227911381
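For context, a rough sketch of the calling pattern the message describes, using the public tf.train.Checkpoint API; the checkpoint directory and variable names are illustrative, not taken from this commit.

import tensorflow as tf

# Create and track the variables *before* calling restore(), so the
# checkpoint reads for both of them can be issued as one batched read.
ckpt = tf.train.Checkpoint(
    weights=tf.Variable(tf.zeros([10, 10])),
    bias=tf.Variable(tf.zeros([10])))
status = ckpt.restore(tf.train.latest_checkpoint("/tmp/ckpt_dir"))

# A dependency added only after restore() still goes through
# restore-on-create, which this commit leaves unchanged: its value is
# assigned lazily when the variable is eventually created.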
Allen Lavoie 2019-01-04 14:08:26 -08:00 committed by TensorFlower Gardener
parent 962d8821eb
commit efe565bc09
2 changed files with 31 additions and 36 deletions


@@ -176,25 +176,13 @@ class _CheckpointRestoreCoordinator(object):
         raise AssertionError(
             ("Saveable keys changed when validating. Got back %s, was "
              "expecting %s") % (tensor_saveables.keys(), validated_names))
-      for saveable in validated_saveables:
-        if saveable.device:
-          device = saveable_object_util.set_cpu0(saveable.device)
-        else:
-          device = None
-        with ops.device(device):
-          tensors = []
-          for spec in saveable.specs:
-            tensors.append(
-                io_ops.restore_v2(
-                    self.save_path_tensor,
-                    [spec.name],
-                    [spec.slice_spec],
-                    [spec.dtype])[0])
-          restore_op = saveable.restore(tensors, restored_shapes=None)
-        if not context.executing_eagerly():
+      new_restore_ops = functional_saver.restore_from_saveable_objects(
+          self.save_path_tensor, validated_saveables)
+      if not context.executing_eagerly():
+        restore_ops.extend(new_restore_ops)
+        for saveable, restore_op in zip(validated_saveables, new_restore_ops):
           assert saveable.name not in self.restore_ops_by_name
           self.restore_ops_by_name[saveable.name] = restore_op
-          restore_ops.append(restore_op)
     return restore_ops
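In other words, the coordinator no longer builds one single-element restore op per SaveableObject; it hands the whole validated list to functional_saver.restore_from_saveable_objects (added in the second file of this commit), which reads every requested tensor through one RestoreV2 op. A rough sketch of the difference, using the same internal modules the diff uses; the checkpoint keys and prefix path are hypothetical.

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import io_ops

prefix = "/tmp/ckpt_dir/ckpt-1"  # assumed checkpoint prefix

# Before: one RestoreV2 op per saveable, each reading a single tensor.
kernel = io_ops.restore_v2(prefix, ["layer/kernel"], [""], [dtypes.float32])[0]
bias = io_ops.restore_v2(prefix, ["layer/bias"], [""], [dtypes.float32])[0]

# After: a single RestoreV2 op returns the tensors in request order, so a
# distributed file system sees one batched read instead of N small ones.
kernel, bias = io_ops.restore_v2(
    prefix,
    ["layer/kernel", "layer/bias"],   # tensor names
    ["", ""],                         # slice specs; empty means the whole tensor
    [dtypes.float32, dtypes.float32])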


@@ -107,25 +107,32 @@ class Saver(object):
       A scalar string Tensor containing `file_prefix` with control dependencies
       on the restore ops.
     """
-    restore_specs = []
-    tensor_structure = []
-    for saveable in self._saveable_objects:
-      saveable_tensor_structure = []
-      tensor_structure.append(saveable_tensor_structure)
-      for spec in saveable.specs:
-        saveable_tensor_structure.append(spec.name)
-        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
-    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
-    with ops.device("cpu:0"):
-      restored_tensors = io_ops.restore_v2(
-          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
-    structured_restored_tensors = nest.pack_sequence_as(
-        tensor_structure, restored_tensors)
-    restore_ops = []
-    for saveable, restored_tensors in zip(self._saveable_objects,
-                                          structured_restored_tensors):
-      restore_ops.append(saveable.restore(restored_tensors,
-                                          restored_shapes=None))
+    restore_ops = restore_from_saveable_objects(
+        file_prefix, self._saveable_objects)
     with ops.device("cpu:0"):
       with ops.control_dependencies(restore_ops):
         return array_ops.identity(file_prefix)
+
+
+def restore_from_saveable_objects(file_prefix, saveable_objects):
+  """Reads from a checkpoint and returns restore ops for `saveable_objects`s."""
+  restore_specs = []
+  tensor_structure = []
+  for saveable in saveable_objects:
+    saveable_tensor_structure = []
+    tensor_structure.append(saveable_tensor_structure)
+    for spec in saveable.specs:
+      saveable_tensor_structure.append(spec.name)
+      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
+  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
+  with ops.device("cpu:0"):
+    restored_tensors = io_ops.restore_v2(
+        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
+  structured_restored_tensors = nest.pack_sequence_as(
+      tensor_structure, restored_tensors)
+  restore_ops = []
+  for saveable, restored_tensors in zip(saveable_objects,
+                                        structured_restored_tensors):
+    restore_ops.append(saveable.restore(restored_tensors,
+                                        restored_shapes=None))
+  return restore_ops
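One detail worth noting in the new helper: tensor_structure records one sublist of spec names per saveable, so nest.pack_sequence_as can regroup the single flat list returned by restore_v2 into one sublist per SaveableObject before each saveable's restore() is called. A minimal illustration of that regrouping, with plain strings standing in for restored tensors and made-up spec names:

from tensorflow.python.util import nest

# One sublist of spec names per saveable, as built in the loop above.
tensor_structure = [["dense/kernel"], ["dense/bias"], ["step", "step/ema"]]
# restore_v2 returns a flat list in the order the specs were appended.
flat_restored = ["K", "B", "S", "E"]

print(nest.pack_sequence_as(tensor_structure, flat_restored))
# [['K'], ['B'], ['S', 'E']]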