Correctly set the experimental_io_device when restoring variable from a checkpoint.

PiperOrigin-RevId: 320222381
Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
This commit is contained in:
Ken Franko 2020-07-08 10:58:30 -07:00
parent 14b2d686d6
commit b8694e39d8
1 changed files with 3 additions and 2 deletions

View File

@ -293,9 +293,10 @@ class CheckpointPosition(object):
checkpoint_key = serialized_tensor.checkpoint_key
dtype = self._checkpoint.dtype_map[checkpoint_key]
base_type = dtype.base_dtype
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
with ops.init_scope():
with ops.device("/cpu:0"):
# Run the restore itself on the CPU.
with ops.device(io_device):
# Run the restore itself on the io_device(CPU or specified).
value, = io_ops.restore_v2(
prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],