Correctly set the experimental_io_device when restoring variable from a checkpoint.
PiperOrigin-RevId: 320222381 Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
This commit is contained in:
parent
14b2d686d6
commit
b8694e39d8
|
@ -293,9 +293,10 @@ class CheckpointPosition(object):
|
||||||
checkpoint_key = serialized_tensor.checkpoint_key
|
checkpoint_key = serialized_tensor.checkpoint_key
|
||||||
dtype = self._checkpoint.dtype_map[checkpoint_key]
|
dtype = self._checkpoint.dtype_map[checkpoint_key]
|
||||||
base_type = dtype.base_dtype
|
base_type = dtype.base_dtype
|
||||||
|
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
with ops.device("/cpu:0"):
|
with ops.device(io_device):
|
||||||
# Run the restore itself on the CPU.
|
# Run the restore itself on the io_device(CPU or specified).
|
||||||
value, = io_ops.restore_v2(
|
value, = io_ops.restore_v2(
|
||||||
prefix=self._checkpoint.save_path_tensor,
|
prefix=self._checkpoint.save_path_tensor,
|
||||||
tensor_names=[checkpoint_key],
|
tensor_names=[checkpoint_key],
|
||||||
|
|
Loading…
Reference in New Issue