diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 9337adbf88a..47fbdddd4d9 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -293,9 +293,10 @@ class CheckpointPosition(object): checkpoint_key = serialized_tensor.checkpoint_key dtype = self._checkpoint.dtype_map[checkpoint_key] base_type = dtype.base_dtype + io_device = self._checkpoint.options.experimental_io_device or "cpu:0" with ops.init_scope(): - with ops.device("/cpu:0"): - # Run the restore itself on the CPU. + with ops.device(io_device): + # Run the restore itself on the io_device(CPU or specified). value, = io_ops.restore_v2( prefix=self._checkpoint.save_path_tensor, tensor_names=[checkpoint_key],