Pull request for fixing warm-starting device placement (#17312)
* Update checkpoint_utils.py: fix device allocation bug for warm-starting op.
* Update checkpoint_utils_test.py: fix test.
This commit is contained in:
parent
7f53659bc6
commit
3ba1f72f88
tensorflow/python/training
@@ -289,7 +289,11 @@ def _set_checkpoint_initializer(variable,
|
||||
name: Name of the operation.
|
||||
"""
|
||||
base_type = variable.dtype.base_dtype
|
||||
with ops.colocate_with(variable):
|
||||
# Do not colocate with variable since RestoreV2 op only runs on CPU and
|
||||
# colocation will force variable (and other ops that colocate with variable)
|
||||
# to be on CPU as well. It is okay to place the variable's initializer op on
|
||||
# CPU since it will only be run once at the start.
|
||||
with ops.device(variable.device), ops.device("/cpu:0"):
|
||||
restore_op = io_ops.restore_v2(
|
||||
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||
variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access
|
||||
|
@@ -206,7 +206,9 @@ class CheckpointsTest(test.TestCase):
|
||||
|
||||
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
|
||||
{"useful_scope/": "useful_scope/"})
|
||||
self.assertEqual(my4._initializer_op.op.inputs[1].device, "/job:ps")
|
||||
# initializer runs on the same task but always on CPU.
|
||||
self.assertEqual(my4._initializer_op.op.inputs[1].device,
|
||||
"/job:ps/device:CPU:0")
|
||||
|
||||
def testInitFromRootCheckpoint(self):
|
||||
checkpoint_dir = self.get_temp_dir()
|
||||
|
Loading…
Reference in New Issue
Block a user