From b8694e39d86d1c08e4dd240536dbf4d90c7eac32 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Wed, 8 Jul 2020 10:58:30 -0700 Subject: [PATCH] Correctly set the experimental_io_device when restoring variable from a checkpoint. PiperOrigin-RevId: 320222381 Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227 --- tensorflow/python/training/tracking/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 9337adbf88a..47fbdddd4d9 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -293,9 +293,10 @@ class CheckpointPosition(object): checkpoint_key = serialized_tensor.checkpoint_key dtype = self._checkpoint.dtype_map[checkpoint_key] base_type = dtype.base_dtype + io_device = self._checkpoint.options.experimental_io_device or "cpu:0" with ops.init_scope(): - with ops.device("/cpu:0"): - # Run the restore itself on the CPU. + with ops.device(io_device): + # Run the restore itself on the io_device(CPU or specified). value, = io_ops.restore_v2( prefix=self._checkpoint.save_path_tensor, tensor_names=[checkpoint_key],