Correctly set the experimental_io_device when restoring variable from a checkpoint.
PiperOrigin-RevId: 320222381 Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
This commit is contained in:
parent
14b2d686d6
commit
b8694e39d8
|
@ -293,9 +293,10 @@ class CheckpointPosition(object):
|
|||
checkpoint_key = serialized_tensor.checkpoint_key
|
||||
dtype = self._checkpoint.dtype_map[checkpoint_key]
|
||||
base_type = dtype.base_dtype
|
||||
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
|
||||
with ops.init_scope():
|
||||
with ops.device("/cpu:0"):
|
||||
# Run the restore itself on the CPU.
|
||||
with ops.device(io_device):
|
||||
# Run the restore itself on the io_device(CPU or specified).
|
||||
value, = io_ops.restore_v2(
|
||||
prefix=self._checkpoint.save_path_tensor,
|
||||
tensor_names=[checkpoint_key],
|
||||
|
|
Loading…
Reference in New Issue