From b8694e39d86d1c08e4dd240536dbf4d90c7eac32 Mon Sep 17 00:00:00 2001
From: Ken Franko <kfranko@google.com>
Date: Wed, 8 Jul 2020 10:58:30 -0700
Subject: [PATCH] Correctly set the experimental_io_device when restoring
 variable from a checkpoint.

PiperOrigin-RevId: 320222381
Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
---
 tensorflow/python/training/tracking/base.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index 9337adbf88a..47fbdddd4d9 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -293,9 +293,10 @@ class CheckpointPosition(object):
       checkpoint_key = serialized_tensor.checkpoint_key
       dtype = self._checkpoint.dtype_map[checkpoint_key]
       base_type = dtype.base_dtype
+      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
       with ops.init_scope():
-        with ops.device("/cpu:0"):
-          # Run the restore itself on the CPU.
+        with ops.device(io_device):
+          # Run the restore itself on the io_device(CPU or specified).
           value, = io_ops.restore_v2(
               prefix=self._checkpoint.save_path_tensor,
               tensor_names=[checkpoint_key],