Pull request for fixing warm-starting device placement ()

* Update checkpoint_utils.py

Fix device allocation bug for warm-starting op

* Update checkpoint_utils_test.py

Fix test
This commit is contained in:
vihanjain 2018-02-27 16:05:26 -08:00 committed by Gunhan Gulsoy
parent 7f53659bc6
commit 3ba1f72f88
2 changed files with 8 additions and 2 deletions

View File

@ -289,7 +289,11 @@ def _set_checkpoint_initializer(variable,
name: Name of the operation.
"""
base_type = variable.dtype.base_dtype
with ops.colocate_with(variable):
# Do not colocate with variable since RestoreV2 op only runs on CPU and
# colocation will force variable (and other ops that colocate with variable)
# to be on CPU as well. It is okay to place the variable's initializer op on
# CPU since it will only be run once at the start.
with ops.device(variable.device), ops.device("/cpu:0"):
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access

View File

@ -206,7 +206,9 @@ class CheckpointsTest(test.TestCase):
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
{"useful_scope/": "useful_scope/"})
self.assertEqual(my4._initializer_op.op.inputs[1].device, "/job:ps")
# initializer runs on the same task but always on CPU.
self.assertEqual(my4._initializer_op.op.inputs[1].device,
"/job:ps/device:CPU:0")
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()