Make the initial tf.train.Checkpoint.restore() read Tensors in a batch
Should be faster when reading from distributed file systems. Does not affect cases where restore-on-create is necessary, but as long as variable objects have been created and tracked before restore() is called, their reads should be batched together.

PiperOrigin-RevId: 227911381
parent 962d8821eb
commit efe565bc09
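At the op level, the change replaces one RestoreV2 call per SaveSpec with a single RestoreV2 call covering all pending specs. A minimal standalone sketch of the difference (the checkpoint prefix and tensor names below are invented for illustration, not taken from the commit):

    import tensorflow as tf

    # Illustrative tensor names and a throwaway checkpoint; not part of the commit.
    prefix = "/tmp/batched_restore_demo_ckpt"
    names = ["dense/kernel", "dense/bias"]
    tf.raw_ops.SaveV2(prefix=prefix, tensor_names=names, shape_and_slices=["", ""],
                      tensors=[tf.constant([1.0, 2.0]), tf.constant([3.0])])

    # Before: one RestoreV2 op (and one round of file reads) per tensor.
    per_tensor = [
        tf.raw_ops.RestoreV2(prefix=prefix, tensor_names=[n],
                             shape_and_slices=[""], dtypes=[tf.float32])[0]
        for n in names
    ]

    # After: a single RestoreV2 op reads every requested tensor in one batch.
    batched = tf.raw_ops.RestoreV2(prefix=prefix, tensor_names=names,
                                   shape_and_slices=["", ""],
                                   dtypes=[tf.float32] * len(names))

A single op hands the filesystem layer one large read request instead of many small ones, which is where the speedup on distributed file systems is expected to come from.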
@@ -176,25 +176,13 @@ class _CheckpointRestoreCoordinator(object):
       raise AssertionError(
           ("Saveable keys changed when validating. Got back %s, was "
            "expecting %s") % (tensor_saveables.keys(), validated_names))
-    for saveable in validated_saveables:
-      if saveable.device:
-        device = saveable_object_util.set_cpu0(saveable.device)
-      else:
-        device = None
-      with ops.device(device):
-        tensors = []
-        for spec in saveable.specs:
-          tensors.append(
-              io_ops.restore_v2(
-                  self.save_path_tensor,
-                  [spec.name],
-                  [spec.slice_spec],
-                  [spec.dtype])[0])
-        restore_op = saveable.restore(tensors, restored_shapes=None)
-      if not context.executing_eagerly():
-        assert saveable.name not in self.restore_ops_by_name
-        self.restore_ops_by_name[saveable.name] = restore_op
-        restore_ops.append(restore_op)
+    new_restore_ops = functional_saver.restore_from_saveable_objects(
+        self.save_path_tensor, validated_saveables)
+    if not context.executing_eagerly():
+      restore_ops.extend(new_restore_ops)
+      for saveable, restore_op in zip(validated_saveables, new_restore_ops):
+        assert saveable.name not in self.restore_ops_by_name
+        self.restore_ops_by_name[saveable.name] = restore_op
     return restore_ops
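The coordinator change above only batches saveables whose variables already exist and are tracked when restore() runs; restore-on-create keeps its previous per-variable path. A hedged usage sketch of the two cases (the directory name and model class here are hypothetical):

    import tensorflow as tf

    class Model(tf.Module):
      def __init__(self):
        self.w = tf.Variable(tf.zeros([8, 4]))
        self.b = tf.Variable(tf.zeros([4]))

    ckpt_dir = "/tmp/ckpts"            # hypothetical checkpoint directory
    model = Model()                    # variables exist and are tracked...
    ckpt = tf.train.Checkpoint(model=model)
    status = ckpt.restore(tf.train.latest_checkpoint(ckpt_dir))
    # ...so their reads are issued as one batched RestoreV2 by the code above.
    # Variables created only after restore() (restore-on-create) still use the
    # unchanged path and are read individually as they are built.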
@@ -107,25 +107,32 @@ class Saver(object):
       A scalar string Tensor containing `file_prefix` with control dependencies
       on the restore ops.
     """
-    restore_specs = []
-    tensor_structure = []
-    for saveable in self._saveable_objects:
-      saveable_tensor_structure = []
-      tensor_structure.append(saveable_tensor_structure)
-      for spec in saveable.specs:
-        saveable_tensor_structure.append(spec.name)
-        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
-    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
-    with ops.device("cpu:0"):
-      restored_tensors = io_ops.restore_v2(
-          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
-    structured_restored_tensors = nest.pack_sequence_as(
-        tensor_structure, restored_tensors)
-    restore_ops = []
-    for saveable, restored_tensors in zip(self._saveable_objects,
-                                          structured_restored_tensors):
-      restore_ops.append(saveable.restore(restored_tensors,
-                                          restored_shapes=None))
+    restore_ops = restore_from_saveable_objects(
+        file_prefix, self._saveable_objects)
     with ops.device("cpu:0"):
       with ops.control_dependencies(restore_ops):
         return array_ops.identity(file_prefix)
+
+
+def restore_from_saveable_objects(file_prefix, saveable_objects):
+  """Reads from a checkpoint and returns restore ops for `saveable_objects`s."""
+  restore_specs = []
+  tensor_structure = []
+  for saveable in saveable_objects:
+    saveable_tensor_structure = []
+    tensor_structure.append(saveable_tensor_structure)
+    for spec in saveable.specs:
+      saveable_tensor_structure.append(spec.name)
+      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
+  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
+  with ops.device("cpu:0"):
+    restored_tensors = io_ops.restore_v2(
+        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
+  structured_restored_tensors = nest.pack_sequence_as(
+      tensor_structure, restored_tensors)
+  restore_ops = []
+  for saveable, restored_tensors in zip(saveable_objects,
+                                        structured_restored_tensors):
+    restore_ops.append(saveable.restore(restored_tensors,
+                                        restored_shapes=None))
+  return restore_ops
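The extracted helper builds a flat list of specs, issues one restore_v2 for all of them, and then uses nest.pack_sequence_as to hand each saveable just its own tensors. A tiny standalone illustration of that regrouping, with invented names and values:

    import tensorflow as tf

    # tensor_structure mirrors the per-saveable spec names; pack_sequence_as
    # slices the flat RestoreV2 output back into one sublist per saveable.
    tensor_structure = [["v0/part_0", "v0/part_1"], ["v1"]]
    flat_restored = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]
    grouped = tf.nest.pack_sequence_as(tensor_structure, flat_restored)
    # grouped == [[1.0, 2.0], [3.0]]; each sublist is what gets passed to
    # saveable.restore(restored_tensors, restored_shapes=None).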